MLIR 23.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop fusion on parallel loops.
10//
11//===----------------------------------------------------------------------===//
12
14
25#include "mlir/IR/Builders.h"
27#include "mlir/IR/IRMapping.h"
28#include "mlir/IR/Matchers.h"
32#include "mlir/IR/Value.h"
34
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SetVector.h"
37#include "llvm/ADT/SmallBitVector.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/Support/InterleavedRange.h"
40
41#include "llvm/Support/DebugLog.h"
42#include <numeric>
43#include <optional>
44#include <tuple>
45#define DEBUG_TYPE "parallel-loop-fusion"
46
47namespace mlir {
48#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
49#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
50} // namespace mlir
51
52using namespace mlir;
53using namespace mlir::scf;
54
55/// Verify there are no nested ParallelOps.
56static bool hasNestedParallelOp(ParallelOp ploop) {
57 auto walkResult =
58 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
59 return walkResult.wasInterrupted();
60}
61
62/// Verify equal iteration spaces.
63static bool equalIterationSpaces(ParallelOp firstPloop,
64 ParallelOp secondPloop) {
65 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
66 return false;
67
68 auto matchOperands = [&](const OperandRange &lhs,
69 const OperandRange &rhs) -> bool {
70 // TODO: Extend this to support aliases and equal constants.
71 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
72 };
73 return matchOperands(firstPloop.getLowerBound(),
74 secondPloop.getLowerBound()) &&
75 matchOperands(firstPloop.getUpperBound(),
76 secondPloop.getUpperBound()) &&
77 matchOperands(firstPloop.getStep(), secondPloop.getStep());
78}
79
80/// Check if both operations are the same type of memory write op and
81/// write to the same memory location (same buffer and same indices).
///
/// Evaluation order: a null op or differing op names give `false`; the very
/// same operation gives `true`; any op other than memref.store,
/// vector.transfer_write or vector.store gives `false`. Otherwise the result
/// is the per-op comparison performed in the type switch below.
// NOTE(review): listing lines 82 (the function signature) and 91 (the head of
// the `.Case` chain) are absent from this listing; the body uses `op1` and
// `op2` as `Operation *`.
83 if (!op1 || !op2 || op1->getName() != op2->getName())
84 return false;
85 if (op1 == op2)
86 return true;
87 // support only these memory-writing ops for now
88 if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
89 return false;
90 bool opsAreIdentical =
92 .Case([&](memref::StoreOp storeOp1) {
// The op names were checked to be equal above, so op2 is cast to the same
// op class. memref.store: compare destination memref and indices.
93 auto storeOp2 = cast<memref::StoreOp>(op2);
94 return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
95 (storeOp1.getIndices() == storeOp2.getIndices());
96 })
97 .Case([&](vector::TransferWriteOp writeOp1) {
// vector.transfer_write: compare base, indices, mask, the type of the stored
// value and the in_bounds attribute.
98 auto writeOp2 = cast<vector::TransferWriteOp>(op2);
99 return (writeOp1.getBase() == writeOp2.getBase()) &&
100 (writeOp1.getIndices() == writeOp2.getIndices()) &&
101 (writeOp1.getMask() == writeOp2.getMask()) &&
102 (writeOp1.getValueToStore().getType() ==
103 writeOp2.getValueToStore().getType()) &&
104 (writeOp1.getInBounds() == writeOp2.getInBounds());
105 })
106 .Case([&](vector::StoreOp vecStoreOp1) {
// vector.store: compare base, indices, the type of the stored value, the
// alignment and the nontemporal flag.
107 auto vecStoreOp2 = cast<vector::StoreOp>(op2);
108 return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
109 (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
110 (vecStoreOp1.getValueToStore().getType() ==
111 vecStoreOp2.getValueToStore().getType()) &&
112 (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
113 (vecStoreOp1.getNontemporal() ==
114 vecStoreOp2.getNontemporal());
115 })
116 .Default([](Operation *) { return false; });
117 return opsAreIdentical;
118}
119
120/// Check if val1 (from the first parallel loop) and val2 (from the
121/// second) are equivalent, considering the mapping of induction variables from
122/// the first to the second parallel loop.
///
/// Two values are equivalent when they are the same value, when one maps to
/// the other through `loopsIVsMap` (checked in both directions), or when both
/// are produced by memory-effect-free ops that pass the operation equivalence
/// check at the end of the function.
123static bool valsAreEquivalent(Value val1, Value val2,
124 const IRMapping &loopsIVsMap) {
125 if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
126 loopsIVsMap.lookupOrDefault(val2) == val1)
127 return true;
// Values without a defining op (e.g. block arguments) that did not match
// above are not equivalent.
128 Operation *val1DefOp = val1.getDefiningOp();
129 Operation *val2DefOp = val2.getDefiningOp();
130 if (!val1DefOp || !val2DefOp)
131 return false;
// Only compare producers that have no memory effects.
132 if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
133 return false;
// NOTE(review): listing line 134 (the head of this return statement) is
// absent from this listing; the flags argument suggests a call into
// OperationEquivalence - confirm against the original file. The callback
// below accepts an operand pair only when one maps to the other through
// `loopsIVsMap`.
135 val1DefOp, val2DefOp,
136 [&](Value v1, Value v2) {
137 return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
138 loopsIVsMap.lookupOrDefault(v2) == v1);
139 },
140 /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
141}
142
143/// If the `expr` value is the result of an integer addition of `base` and a
144/// constant, return the constant.
145static std::optional<int64_t> getAddConstant(Value expr, Value base,
146 const IRMapping &loopsIVsMap) {
147 if (auto addOp = expr.getDefiningOp<arith::AddIOp>()) {
148 if (auto constOp = getConstantIntValue(addOp.getLhs());
149 constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
150 return constOp.value();
151 if (auto constOp = getConstantIntValue(addOp.getRhs());
152 constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
153 return constOp.value();
154 return std::nullopt;
155 }
156
157 if (auto addOp = expr.getDefiningOp<index::AddOp>()) {
158 if (auto constOp = getConstantIntValue(addOp.getLhs());
159 constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
160 return constOp.value();
161 if (auto constOp = getConstantIntValue(addOp.getRhs());
162 constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
163 return constOp.value();
164 return std::nullopt;
165 }
166
167 if (auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>()) {
168 AffineMap map = applyOp.getAffineMap();
169 if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
170 map.getNumSymbols() != 0)
171 return std::nullopt;
172 if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
173 return std::nullopt;
174 AffineExpr result = map.getResult(0);
175 auto bin = dyn_cast<AffineBinaryOpExpr>(result);
176 if (!bin || bin.getKind() != AffineExprKind::Add)
177 return std::nullopt;
178 auto lhsDim = dyn_cast<AffineDimExpr>(bin.getLHS());
179 auto rhsDim = dyn_cast<AffineDimExpr>(bin.getRHS());
180 auto lhsConst = dyn_cast<AffineConstantExpr>(bin.getLHS());
181 auto rhsConst = dyn_cast<AffineConstantExpr>(bin.getRHS());
182 if (lhsConst && rhsDim)
183 return lhsConst.getValue();
184 if (rhsConst && lhsDim)
185 return rhsConst.getValue();
186 }
187 return std::nullopt;
188}
189
190// Return true if the scalar load index can be shown to fall within the
191// elements covered by a vector.store/transfer_write along a single memref
// dimension; return false whenever this cannot be established. Supported cases:
192//
193// 1) Direct index match (with optional offset):
194// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
195// %x = memref.load %A[%i] : memref<...>
196//
197// 2) Loop IV range fully contained in the write range:
198// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
199// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
200//
201// 3) Constant index (or IV + constant) within the write range:
202// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
203// %x = memref.load %A[%c2] : memref<...>
204// %y = memref.load %A[%i + %c1] : memref<...>
205//
206// Args:
207// - loadIndex: index used by the scalar load for this dimension.
208// - offset: subview offset for the base memref dimension (if any).
209// - writeIndex: index used by the transfer_write for this dimension. Can be
210// null if the dim was dropped by a rank reducing subview, whose result is
211// written by the vector.write.
212// - extent: vector size along this dimension (number of elements written).
//   A non-positive extent always yields false.
213// - loopsIVsMap: IV equivalence map between fused loops.
214static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
215 Value writeIndex, int64_t extent,
216 const IRMapping &loopsIVsMap) {
217 if (extent <= 0)
218 return false;
219
220 // Extract constant loop bounds for loop IVs (e.g. from scf.for).
// Yields (lb, ub, step) when `index` is a block argument that is one of the
// induction variables of a LoopLikeOpInterface op for which
// getConstLoopBounds returns an entry; std::nullopt otherwise.
221 auto getConstLoopBoundsForIV =
222 [](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
223 auto blockArg = dyn_cast<BlockArgument>(index);
224 if (!blockArg)
225 return std::nullopt;
226 auto *parentOp = blockArg.getOwner()->getParentOp();
227 auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
228 if (!loopLike)
229 return std::nullopt;
230 auto ranges = getConstLoopBounds(loopLike);
231 if (ranges.empty())
232 return std::nullopt;
233
234 auto ivs = loopLike.getLoopInductionVars();
235 if (!ivs)
236 return std::nullopt;
237 auto it = llvm::find(*ivs, blockArg);
238 if (it == ivs->end())
239 return std::nullopt;
// Position of the IV selects the matching entry in `ranges`.
240 unsigned pos = std::distance(ivs->begin(), it);
241 if (pos >= ranges.size())
242 return std::nullopt;
243 auto [lb, ub, step] = ranges[pos];
244 return std::make_tuple(lb, ub, step);
245 };
246
// A missing write index (dropped dimension) is treated as the constant 0.
247 std::optional<int64_t> offsetConst = getConstantIntValue(offset);
248 std::optional<int64_t> writeConst =
249 writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
250 if (!writeConst && writeIndex) {
251 // Treat single-iteration IVs as constants for matching.
252 if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
253 auto [lb, ub, step] = *bounds;
254 if (step > 0 && ub == lb + step)
255 writeConst = lb;
256 }
257 }
258
259 // Check whether a loop IV is fully contained in a constant write range.
// The write range is the half-open interval
// [rangeStart, rangeStart + rangeExtent); `ub` is compared as an exclusive
// bound.
260 auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
261 int64_t rangeStart, int64_t rangeExtent) -> bool {
262 if (rangeExtent <= 0 || step <= 0)
263 return false;
264 if (ub <= lb)
265 return false;
266 int64_t rangeEnd = rangeStart + rangeExtent;
267 return lb >= rangeStart && ub <= rangeEnd;
268 };
269
270 if (offsetConst && writeConst) {
271 // Constant start of the write range; check constant load or loop IV range.
272 int64_t start = *offsetConst + *writeConst;
273 if (auto loadConst = getConstantIntValue(loadIndex))
274 return (*loadConst >= start && *loadConst < start + extent);
275 if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
276 auto [lb, ub, step] = *bounds;
277 return loopIVWithinRange(lb, ub, step, start, extent);
278 }
// Neither a constant load index nor an IV with constant bounds: fall
// through to the symbolic checks below.
279 }
280
281 if (writeIndex) {
282 // Direct IV match (or IV + constant) against the write index.
283 if (offsetConst && *offsetConst == 0 &&
284 valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
285 return true;
286 if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
287 // Match load index of the form writeIndex + C within the write extent.
// With a constant subview offset, C must lie in [offset, offset + extent).
288 if (offsetConst) {
289 int64_t start = *offsetConst;
290 return (*addConst >= start && *addConst < start + extent);
291 }
292 }
293 return false;
294 }
295
// From here on there is no write index (dropped dimension).
296 if (auto offsetVal = dyn_cast<Value>(offset)) {
297 // Exact match when extent is 1 and the load hits the offset value.
298 if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
299 return true;
300 }
301
302 return false;
303}
304
305/// Return the base memref value used by the given memory op.
///
/// Handles memref.load/store, vector.transfer_read/transfer_write and
/// vector.load/store; any other op yields a null `Value`.
// NOTE(review): listing lines 306 (the function signature) and 308 (the head
// of the `.Case` chain) are absent from this listing.
307 // TODO: use the common interface for memory ops once available.
309 .Case([&](memref::LoadOp load) { return load.getMemRef(); })
310 .Case([&](memref::StoreOp store) { return store.getMemRef(); })
311 .Case([&](vector::TransferReadOp read) { return read.getBase(); })
312 .Case([&](vector::TransferWriteOp write) { return write.getBase(); })
313 .Case([&](vector::LoadOp load) { return load.getBase(); })
314 .Case([&](vector::StoreOp store) { return store.getBase(); })
315 .Default([](Operation *) { return Value(); });
316}
317
318/// Recognize scalar memref.load of an element produced by a vector write
319/// (vector.transfer_write or vector.store, optionally through a rank-reducing
320/// unit-stride subview) of the same buffer. This covers the pattern where a
321/// vector write stores a full lane pack and a subsequent scalar load reads an
322/// element from that lane pack. EXAMPLE:
323/// vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
324/// vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
325/// scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
326/// %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
327/// ...
328/// }
329///
/// Args:
/// - loadOp: the scalar load to check.
/// - writeBase / writeIndices: base and indices of the vector write.
/// - vecTy: type of the written vector; a null type yields false.
/// - vectorDimForWriteDim: for each write index position, the vector
///   dimension written along it, or -1 when no vector dimension maps to it.
/// - ivsMap: IV equivalence map forwarded to loadIndexWithinWriteRange.
/// Returns true only when every dimension of the loaded memref passes
/// loadIndexWithinWriteRange.
330static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
331 ValueRange writeIndices, VectorType vecTy,
332 ArrayRef<int64_t> vectorDimForWriteDim,
333 const IRMapping &ivsMap) {
334 if (!vecTy)
335 return false;
336
337 Value base = writeBase;
338 // The write base if there is no subview, or the subview source otherwise.
339 MemrefValue baseMemref = nullptr;
// NOTE(review): listing line 340 (presumably the declaration of `offsets`,
// which is assigned and indexed below) is absent from this listing.
341 llvm::SmallBitVector droppedDims;
342 bool hasSubview = false;
343 auto *ctx = loadOp.getContext();
344 if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
// Only unit-stride subviews are supported.
345 if (!subView.hasUnitStride())
346 return false;
347 baseMemref = cast<MemrefValue>(subView.getSource());
348 offsets = llvm::to_vector(subView.getMixedOffsets());
349 droppedDims = subView.getDroppedDims();
350 hasSubview = true;
351 } else {
352 baseMemref = dyn_cast<MemrefValue>(base);
353 if (!baseMemref)
354 return false;
355 }
356
// The load must read directly from the resolved base memref with one index
// per base dimension.
357 auto loadIndices = loadOp.getIndices();
358 unsigned baseRank = baseMemref.getType().getRank();
359 if ((loadOp.getMemref() != baseMemref) || (loadIndices.size() != baseRank))
360 return false;
361
// Rank consistency between the write indices, the subview offsets and the
// write-dim to vector-dim table.
362 unsigned writeRank = writeIndices.size();
363 if ((!hasSubview && writeRank != baseRank) ||
364 (hasSubview && offsets.size() != baseRank) ||
365 (vectorDimForWriteDim.size() != writeRank))
366 return false;
367
368 auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
// `writeMemrefDim` tracks the position in `writeIndices`; it only advances
// for base dimensions that were not dropped by the subview.
369 unsigned writeMemrefDim = 0;
370 for (unsigned baseDim : llvm::seq(baseRank)) {
371 bool wasDropped = (hasSubview && droppedDims.test(baseDim));
372 int64_t vectorDim = !wasDropped ? vectorDimForWriteDim[writeMemrefDim] : -1;
// Dimensions not covered by a vector dimension are written with extent 1.
373 int64_t extent = 1;
374 if (vectorDim >= 0) {
375 int64_t dimSize = vecTy.getDimSize(vectorDim);
376 if (dimSize == ShapedType::kDynamic)
377 return false;
378 extent = dimSize;
379 }
380 Value writeIndex = !wasDropped ? writeIndices[writeMemrefDim] : Value();
// Without a subview the offset is the constant 0.
381 OpFoldResult offset =
382 hasSubview ? offsets[baseDim] : OpFoldResult(zeroAttr);
383 if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
384 extent, ivsMap))
385 return false;
386 if (!wasDropped)
387 ++writeMemrefDim;
388 }
389
390 return true;
391}
392
393/// Recognize scalar memref.load of an element produced by a
394/// vector.transfer_write
395static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
396 vector::TransferWriteOp writeOp,
397 const IRMapping &ivsMap) {
398 auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
399 if (!vecTy)
400 return false;
401
402 unsigned writeRank = writeOp.getIndices().size();
403 AffineMap permutationMap = writeOp.getPermutationMap();
404 if (!permutationMap.isProjectedPermutation() ||
405 permutationMap.getNumResults() != vecTy.getRank() ||
406 permutationMap.getNumDims() != writeRank)
407 return false;
408
409 SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
410 for (unsigned vecDim = 0; vecDim < permutationMap.getNumResults(); ++vecDim) {
411 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(vecDim));
412 if (!dimExpr)
413 return false;
414 unsigned writeDim = dimExpr.getPosition();
415 if (writeDim >= writeRank || vectorDimForWriteDim[writeDim] != -1)
416 return false;
417 vectorDimForWriteDim[writeDim] = vecDim;
418 }
419
420 return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
421 vecTy, vectorDimForWriteDim, ivsMap);
422}
423
424/// Recognize scalar memref.load of an element produced by a vector.store
425static bool loadMatchesVectorStore(memref::LoadOp loadOp,
426 vector::StoreOp storeOp,
427 const IRMapping &ivsMap) {
428 auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
429 if (!vecTy)
430 return false;
431
432 unsigned writeRank = storeOp.getIndices().size();
433 if (vecTy.getRank() > writeRank)
434 return false;
435
436 SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
437 unsigned vecRank = vecTy.getRank();
438 for (unsigned i = 0; i < vecRank; ++i) {
439 unsigned writeDim = writeRank - vecRank + i;
440 vectorDimForWriteDim[writeDim] = i;
441 }
442
443 return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
444 vecTy, vectorDimForWriteDim, ivsMap);
445}
446
447/// Check if both operations access the same positions of the same
448/// buffer, but one of the two does it through a rank-reducing full subview of
449/// the buffer (the other's base). EXAMPLE:
450/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
451/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
452/// to memref<2x2xf32>
453/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
///
/// Both directions are tried: op1 accessing a subview of op2's base, and op2
/// accessing a subview of op1's base. The builder `b` is passed to the index
/// resolution helper used by the lambda below.
454template <typename OpTy1, typename OpTy2>
// NOTE(review): listing line 455 (the line carrying the function name) is
// absent from this listing.
456 OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
457 OpBuilder &b) {
// NOTE(review): getBaseMemref returns a null Value for unsupported ops, and
// cast<> is expected to assert on a null input, which would make the null
// check below unreachable - confirm; a null-tolerant cast may be intended.
458 auto base1 = cast<MemrefValue>(getBaseMemref(op1));
459 auto base2 = cast<MemrefValue>(getBaseMemref(op2));
460 if (!base1 || !base2)
461 return false;
462
// Resolve the indices used on the subview result into indices on the subview
// source, then require each resolved index to be either a constant zero or
// equivalent to the index used by the direct access at the same position.
463 auto accessThroughTrivialSubviewIsSame =
464 [&b](memref::SubViewOp subView, ValueRange subViewAccess,
465 ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
466 SmallVector<Value> resolvedSubviewAccess;
467 LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
468 subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
469 if (failed(resolved) ||
470 (resolvedSubviewAccess.size() != sourceAccess.size()))
471 return false;
472 for (auto [dimIdx, resolvedIndex] :
473 llvm::enumerate(resolvedSubviewAccess)) {
474 if (!matchPattern(resolvedIndex, m_Zero()) &&
475 !valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
476 return false;
477 }
478 return true;
479 };
480
481 // Case 1: op1 uses a subview of op2's base.
// NOTE(review): listing line 484 (the head of the call taking `base2` and the
// subview source) is absent from this listing.
482 if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
483 subView &&
485 base2, cast<MemrefValue>(subView.getSource())) &&
486 accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
487 op2.getIndices(),
488 firstToSecondPloopIVsMap))
489 return true;
490
491 // Case 2: op2 uses a subview of op1's base.
// NOTE(review): listing line 494 is absent as well (same call as in case 1,
// presumably).
492 if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
493 subView &&
495 base1, cast<MemrefValue>(subView.getSource())) &&
496 accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
497 op1.getIndices(),
498 firstToSecondPloopIVsMap))
499 return true;
500
501 return false;
502}
503
504/// Check if both memory read/write operations access the same indices
505/// (considering also the mapping of induction variables from the first to the
506/// second parallel loop).
507template <typename OpTy1, typename OpTy2>
508static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
509 const IRMapping &loopsIVsMap, OpBuilder &b) {
510 auto indices1 = op1.getIndices();
511 auto indices2 = op2.getIndices();
512 if (indices1.size() != indices2.size())
513 return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
514 for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
515 if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
516 return false;
517 }
518 return true;
519}
520
521/// Check if the loadOp reads from the same memory location (same buffer,
522/// same indices and same properties) as written by the storeOp.
///
/// Supported reads are memref.load, vector.transfer_read and vector.load. A
/// memref.load may be matched against memref.store, vector.transfer_write or
/// vector.store; vector.transfer_read only against vector.transfer_write and
/// vector.load only against vector.store. Null ops and every other
/// combination yield false.
523static bool
// NOTE(review): listing lines 524 (function name and the leading parameters)
// and 533 (the head of the `.Case` chain) are absent from this listing.
525 const IRMapping &firstToSecondPloopIVsMap,
526 OpBuilder &b) {
527 if (!loadOp || !storeOp)
528 return false;
529 // Support only these memory-reading ops for now
530 if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
531 return false;
532 bool accessSameMemory =
534 .Case([&](memref::LoadOp memLoadOp) {
// Scalar load: same-index check against a scalar store, or the "element of
// a written vector" check against the vector write ops.
535 if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
536 return opsAccessSameIndices(memLoadOp, memStoreOp,
537 firstToSecondPloopIVsMap, b);
538 if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
539 return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
540 firstToSecondPloopIVsMap);
541 if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
542 return loadMatchesVectorStore(memLoadOp, vecStoreOp,
543 firstToSecondPloopIVsMap);
544 return false;
545 })
546 .Case([&](vector::TransferReadOp vecReadOp) {
// Besides the indices, the mask and the in_bounds attribute must agree.
547 auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
548 if (!vecWriteOp)
549 return false;
550 return opsAccessSameIndices(vecReadOp, vecWriteOp,
551 firstToSecondPloopIVsMap, b) &&
552 (vecReadOp.getMask() == vecWriteOp.getMask()) &&
553 (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
554 })
555 .Case([&](vector::LoadOp vecLoadOp) {
// Besides the indices, the alignment must agree.
556 auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
557 if (!vecStoreOp)
558 return false;
559 return opsAccessSameIndices(vecLoadOp, vecStoreOp,
560 firstToSecondPloopIVsMap, b) &&
561 (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
562 })
563 .Default([](Operation *) { return false; });
564 return accessSameMemory;
566
// NOTE(review): the header of this helper (listing lines 567-568) and its
// closing lines are absent from this listing; presumably this is the body of
// getStoreOpTargetBuffer, which is called further below - confirm. The chain
// returns the written memref/base of memref.store, vector.transfer_write and
// vector.store, and a null Value for any other op.
569 .Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
570 .Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
571 .Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
572 .Default([](Operation *) { return Value(); });
575/// To be called when `mayAlias(val1, val2)` is true. Check if the potential
576/// aliasing between the loadOp and storeOp can be resolved by analyzing their
577/// access patterns.
578static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
579 const IRMapping &loopsIVsMap) {
580 if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
581 transfWriteOp && isa<memref::LoadOp>(loadOp))
582 return loadMatchesVectorWrite(cast<memref::LoadOp>(loadOp), transfWriteOp,
583 loopsIVsMap);
584 if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
585 vecStoreOp && isa<memref::LoadOp>(loadOp))
586 return loadMatchesVectorStore(cast<memref::LoadOp>(loadOp), vecStoreOp,
587 loopsIVsMap);
588 return false;
589}
590
591/// Check that the parallel loops have no mixed access to the same buffers.
592/// Return `true` if the second parallel loop does not read or write the buffers
593/// written by the first loop using different indices.
///
/// The check has two phases: first all writes of `firstPloop` are collected
/// per target buffer, then every reading op of `secondPloop` is validated
/// against them. Unsupported writing or reading ops make the result `false`.
// NOTE(review): listing lines 594 (return type and function name) and 597
// (the trailing parameters; the body uses `mayAlias` and `b`) are absent from
// this listing.
// NOTE(review): the second walk below skips ops that have no Read effect, so
// only reads of the second loop are validated in this function - confirm that
// the "or write" part of the summary is covered elsewhere.
595 ParallelOp firstPloop, ParallelOp secondPloop,
596 const IRMapping &firstToSecondPloopIndices,
598 // Map buffers to their store/write ops in the firstPloop
599 DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
600 // Record all the memory buffers used in store/write ops found in firstPloop
601 llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
602
603 auto collectStoreOpsInWalk = [&](Operation *op) {
604 auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
605 // Ignore ops that don't write to memory
// (ops with neither a Write nor a Free effect, and ops that do not implement
// the memory effect interface).
606 if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
607 !memOpInterf.hasEffect<MemoryEffects::Free>()))
608 return WalkResult::advance();
609
610 // Only these memory-writing ops are supported for now:
611 // memref.store, vector.transfer_write, vector.store
612 Value storeOpBase = getStoreOpTargetBuffer(op);
613 if (!storeOpBase)
614 return WalkResult::interrupt();
615
616 // Expect the base operand to be a Memref
617 MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
618 if (!storeOpBaseMemref)
619 return WalkResult::interrupt();
620 // Get the original memref buffer, skipping full view-like ops
621 Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
622 bufferStoresInFirstPloop[buffer].push_back(op);
623 buffersWrittenInFirstPloop.insert(buffer);
624 return WalkResult::advance();
625 };
626
627 // Walk the first parallel loop to collect all store/write ops and their
628 // target buffers
629 if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
630 return false;
631
632 // Check that this load/read op encountered while walking the second parallel
633 // loop does not have incompatible data dependencies with the store/write ops
634 // collected from the first parallel loop: the loops can be fused only if in
635 // the 2nd loop there are no loads/stores from/to the buffers written in the
636 // 1st loop, except when on the same exact memory location (same indices) as
637 // written in the 1st loop.
638 auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
639 auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
640 // To be conservative, we should stop on ops that don't advertise their
641 // memory effects. However, many ops don't implement MemoryEffectOpInterface
642 // yet, so for now we just skip them.
643 // TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
// NOTE(review): listing lines 645 and 651 (the second operand of the two
// conditions below) are absent from this listing.
644 if (!memOpInterf &&
646 return WalkResult::advance();
647 // Ignore ops that don't read from memory, and wrapping ops that have nested
648 // memory effects (e.g. loops, conditionals) as they will be analyzed when
649 // visiting their nested ops.
650 if ((!memOpInterf &&
652 (memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
653 return WalkResult::advance();
654 // Support only these memory-reading ops for now
// The base buffer is taken from operand 0, which must be a memref value.
655 if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
656 !isa<MemrefValue>(loadOp->getOperand(0)))
657 return WalkResult::interrupt();
658
659 MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
660 MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
661
// A different buffer written by the first loop that may alias the loaded one
// blocks fusion unless every write to it can be resolved by canResolveAlias.
662 for (Value storedMem : buffersWrittenInFirstPloop)
663 if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
664 !llvm::all_of(bufferStoresInFirstPloop[storedMem],
665 [&](Operation *storeOp) {
666 return canResolveAlias(loadOp, storeOp,
667 firstToSecondPloopIndices);
668 })) {
669 return WalkResult::interrupt();
670 }
671
672 auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
673 if (writeOpsIt == bufferStoresInFirstPloop.end())
674 return WalkResult::advance();
675 // Store/write ops to this buffer in the firstPloop
676 SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
677
678 // If the first loop has no writes to this buffer, continue
679 if (writeOps.empty())
680 return WalkResult::advance();
681
// The first recorded write is the representative all others are compared to.
682 Operation *writeOp = writeOps.front();
683
684 // In the first parallel loop, multiple writes to the same memref are
685 // allowed only on the same memory location
686 if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
687 return opsWriteSameMemLocation(writeOp, otherWriteOp);
688 })) {
689 return WalkResult::interrupt();
690 }
691
692 // Check that the load in secondPloop reads from the same memory location as
693 // written by the corresponding store in firstPloop
694 if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
695 firstToSecondPloopIndices, b)) {
696 return WalkResult::interrupt();
697 }
698
699 return WalkResult::advance();
700 };
701
702 // Walk the second parallel loop to check load/read ops against the stores
703 // collected from the first parallel loop.
704 return !secondPloop.getBody()
705 ->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
706 .wasInterrupted();
707}
708
709/// Check that in each loop there are no read ops on the buffers written
710/// by the other loop, except when reading from the same exact memory location
711/// (same indices) as written in the other loop.
///
/// The check is run in both directions: first -> second with the given IV
/// mapping, then second -> first with a mapping built here from the block
/// arguments of the two loop bodies.
712static bool
713noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
714 const IRMapping &firstToSecondPloopIndices,
// NOTE(review): listing lines 715 (the `mayAlias` parameter), 717 and 724 (the
// heads of the two calls below) are absent from this listing.
716 OpBuilder &b) {
718 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias, b))
719 return false;
720
// Reverse mapping: body arguments of the second loop -> those of the first.
721 IRMapping secondToFirstPloopIndices;
722 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
723 firstPloop.getBody()->getArguments());
725 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
726}
727
728/// Check if fusion of the two parallel loops is legal:
729/// i.e. no nested parallel loops, equal iteration spaces,
730/// and no incompatible data dependencies between the loops.
///
/// Additionally, every user of the first loop must be properly dominated by
/// the second loop.
731static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
732 const IRMapping &firstToSecondPloopIndices,
// NOTE(review): listing line 733 (the `mayAlias` parameter used below) is
// absent from this listing.
734 OpBuilder &b) {
735 if (hasNestedParallelOp(firstPloop) || hasNestedParallelOp(secondPloop) ||
736 !equalIterationSpaces(firstPloop, secondPloop) ||
737 !noIncompatibleDataDependencies(firstPloop, secondPloop,
738 firstToSecondPloopIndices, mayAlias, b))
739 return false;
740
741 // We are fusing first loop into second, make sure there are no users of the
742 // first loop results between loops.
// A user that is not properly dominated by the second loop makes the fusion
// illegal.
743 DominanceInfo dom;
744 for (Operation *user : firstPloop->getUsers()) {
745 if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
746 return false;
747 }
748 return true;
749}
750
751// Returns new parallel loop where two loops matching indices param are
752// interchanged
// Returns std::nullopt when `loop` has fewer than two loops. `indices` must
// have one entry per loop (asserted). The new op is created right before
// `loop`; the visible code does not erase `loop`.
753static std::optional<ParallelOp>
754interchangeLoops(OpBuilder &builder, ParallelOp &loop,
755 const ArrayRef<int64_t> &indices) {
756 assert(loop.getNumLoops() == indices.size());
757 if (loop.getNumLoops() < 2)
758 return std::nullopt;
759
760 // Create, in front of the original loop, a parallel loop whose lower and
// upper bounds are permuted by `indices`.
// NOTE(review): listing lines 767 (the initializer of `newStep`) and 771 (the
// head of the `newIvs` declaration) are absent from this listing.
761 builder.setInsertionPoint(loop);
762 SmallVector<Value> newLB =
763 applyPermutation(SmallVector<Value>(loop.getLowerBound()), indices);
764 SmallVector<Value> newUB =
765 applyPermutation(SmallVector<Value>(loop.getUpperBound()), indices);
766 SmallVector<Value> newStep =
768 auto newOp = ParallelOp::create(builder, loop.getLoc(), newLB, newUB, newStep,
769 loop.getInitVals(), nullptr);
// Map every induction variable of the old loop to the induction variable of
// the new loop selected through the inverse permutation.
770 auto ivs = loop.getInductionVars();
772 newOp.getInductionVars(), invertPermutationVector(indices));
773 IRMapping mapping;
774 for (auto [iv, riv] : llvm::zip(ivs, newIvs)) {
775 mapping.map(iv, riv);
776 }
777
778 // Copy parallel loop body
// With reductions the whole block (terminator included) is cloned; otherwise
// the terminator is left out.
779 auto b = OpBuilder::atBlockBegin(newOp.getBody());
780 for (auto &o : loop.getNumReductions()
781 ? loop.getBodyRegion().front()
782 : loop.getBodyRegion().front().without_terminator()) {
783 b.clone(o, mapping);
784 }
785 return newOp;
786}
787
// Description of one loop dimension by its lower bound, upper bound and step.
// Two descriptions are equal when all three members compare equal.
// NOTE(review): listing line 789 (the declarations of `lBound`, `uBound` and
// `step`) is absent from this listing.
788struct LoopIV {
790 bool operator!=(LoopIV const &other) const { return !(*this == other); }
791 bool operator==(LoopIV const &other) const {
792 return lBound == other.lBound && uBound == other.uBound &&
793 step == other.step;
794 }
795};
796
// Template specialization providing equality and hashing for LoopIV.
// NOTE(review): listing line 798 (the name of the specialized struct,
// presumably llvm::DenseMapInfo<LoopIV>) and lines 805-807 (the arguments of
// hash_combine) are absent from this listing - confirm against the original.
797template <>
// Equality delegates to LoopIV::operator==.
799 static inline bool isEqual(const LoopIV &lhs, const LoopIV &rhs) {
800 return (lhs == rhs);
801 }
802
803 static inline unsigned getHashValue(const LoopIV &val) {
804 return llvm::hash_combine(
808 }
809};
810
811// Returns vector of candidate permutation indices vectors,
812// can be empty. Caps the number of extra candidate permutations
813// explored to avoid combinatorial explosion. This makes the search
814// intentionally incomplete.
817 ParallelOp &secondPloop,
818 int permBudget = 120) {
819 // Check preconditions
820 if (firstPloop.getNumLoops() < 2 ||
821 firstPloop.getNumLoops() != secondPloop.getNumLoops())
822 return {};
823
824 SmallVector<LoopIV> firstIVs(firstPloop.getNumLoops());
825 SmallVector<LoopIV> secondIVs(secondPloop.getNumLoops());
826 llvm::SmallSetVector<LoopIV, 6> unique;
827 for (unsigned index : llvm::seq(firstPloop.getNumLoops())) {
828 firstIVs[index].lBound = firstPloop.getLowerBound()[index];
829 firstIVs[index].uBound = firstPloop.getUpperBound()[index];
830 firstIVs[index].step = firstPloop.getStep()[index];
831 secondIVs[index].lBound = secondPloop.getLowerBound()[index];
832 secondIVs[index].uBound = secondPloop.getUpperBound()[index];
833 secondIVs[index].step = secondPloop.getStep()[index];
834 unique.insert(firstIVs[index]);
835 }
836
837 SmallVector<bool> diffIVs(firstPloop.getNumLoops());
838 llvm::transform(
839 llvm::zip(firstIVs, secondIVs), diffIVs.begin(),
840 [](auto const &pair) { return std::get<0>(pair) != std::get<1>(pair); });
841
843 for (auto [idx, val] : enumerate(diffIVs))
844 if (val)
845 indices.push_back(idx);
846
847 // Not a permutation shortcut
848 if (indices.size() == 1)
849 return {};
850
851 // Initialize with identity permutations
852 SmallVector<int64_t> basic(firstIVs.size());
853 std::iota(basic.begin(), basic.end(), 0);
854
855 if (indices.empty() && unique.size() == firstIVs.size())
856 return {};
857
858 if (indices.size() > 1) {
859 // Determine whether the iteration space of the first loop is a permutation
860 // of the second and collect remaps.
862 for (auto fIdx : indices) {
863 for (auto sIdx : indices) {
864 // can be remapped
865 if (fIdx != sIdx && firstIVs[fIdx] == secondIVs[sIdx] &&
866 remaps.end() == std::find(remaps.begin(), remaps.end(), sIdx)) {
867 remaps.push_back(sIdx);
868 break;
869 }
870 }
871 }
872
873 // Not a permutation
874 if (indices.size() != remaps.size())
875 return {};
876
877 // compose permutation indices
878 for (auto [from, to] : zip(indices, remaps)) {
879 basic[from] = to;
880 }
881
882 LDBG() << "Collected basic permutations: "
883 << llvm::interleaved_array(basic);
884
885 // All axes are unique, no further permutatons needed
886 if (unique.size() == firstIVs.size()) {
887 return {basic};
888 }
889 }
890
891 //
892 // Permute equal axes
893 assert(unique.size() != firstIVs.size() &&
894 "Expected at least two equal axes");
895
896 // Collect equal axes to groups
897 SmallVector<SmallVector<int64_t>> extraResults{basic};
899 for (auto iv : unique) {
901 for (unsigned index : llvm::seq(firstIVs.size())) {
902 if (firstIVs[index] == iv)
903 group.push_back(index);
904 }
905 if (group.size() > 1)
906 groups.push_back(std::move(group));
907 }
908
909 // Permute axes groups
910 SmallVector<SmallVector<int64_t>> rmpdGroups(groups);
911 bool repeat = true;
912 while (repeat && permBudget) {
913 repeat = false;
914 for (auto const &[group, groupRemaps] : zip(groups, rmpdGroups)) {
915 repeat |= std::next_permutation(groupRemaps.begin(), groupRemaps.end());
916 if (repeat)
917 break;
918 }
919
920 if (repeat) {
921 SmallVector<int64_t> extra(basic);
922 for (auto const &[group, groupRemaps] : zip(groups, rmpdGroups)) {
923 for (auto [from, to] : zip(group, groupRemaps))
924 extra[from] = basic[to];
925 }
926 if (basic != extra) {
927 LDBG() << "Collected extra permutations: "
928 << llvm::interleaved_array(extra);
929
930 extraResults.push_back(std::move(extra));
931 permBudget--;
932 }
933 }
934 }
935
936 return extraResults;
937}
938
939/// Prepend operations of firstPloop's body into secondPloop's body.
940/// Update secondPloop with new loop.
941static void applyLoopFusion(ParallelOp &firstPloop, ParallelOp &secondPloop,
942 OpBuilder &builder) {
943 Block *block1 = firstPloop.getBody();
944 Block *block2 = secondPloop.getBody();
945 ValueRange inits1 = firstPloop.getInitVals();
946 ValueRange inits2 = secondPloop.getInitVals();
947
948 SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
949 newInitVars.append(inits2.begin(), inits2.end());
950
951 IRRewriter b(builder);
952 b.setInsertionPoint(secondPloop);
953 auto newSecondPloop = ParallelOp::create(
954 b, secondPloop.getLoc(), secondPloop.getLowerBound(),
955 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
956
957 Block *newBlock = newSecondPloop.getBody();
958 auto term1 = cast<ReduceOp>(block1->getTerminator());
959 auto term2 = cast<ReduceOp>(block2->getTerminator());
960
961 b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
962 newBlock->getArguments());
963 b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
964 newBlock->getArguments());
965
966 ValueRange results = newSecondPloop.getResults();
967 if (!results.empty()) {
968 b.setInsertionPointToEnd(newBlock);
969
970 ValueRange reduceArgs1 = term1.getOperands();
971 ValueRange reduceArgs2 = term2.getOperands();
972 SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
973 newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
974
975 auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs);
976
977 for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
978 term1.getReductions(), term2.getReductions()))) {
979 Block &oldRedBlock = reg.front();
980 Block &newRedBlock = newReduceOp.getReductions()[i].front();
981 b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
982 newRedBlock.getArguments());
983 }
984
985 firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
986 secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
987 }
988 term1->erase();
989 term2->erase();
990 firstPloop.erase();
991 secondPloop.erase();
992 secondPloop = newSecondPloop;
993}
994
995/// Check fusion pre-conditions and call fusion if it is possible
996static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
997 OpBuilder builder,
999 Block *block1 = firstPloop.getBody();
1000 Block *block2 = secondPloop.getBody();
1001 IRMapping firstToSecondPloopIndices;
1002 firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
1003
1004 if (isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
1005 mayAlias, builder)) {
1006 applyLoopFusion(firstPloop, secondPloop, builder);
1007 return;
1008 }
1009
1010 // If iteration space of the second parallel loop is a permutation of the
1011 // first one then interchange iteration space of the second parallel loop
1012 // and re-asses possibility of fusion.
1013 for (auto &perms :
1014 computeCandidateInterchangePermutations(firstPloop, secondPloop)) {
1015 OpBuilder::InsertionGuard guard(builder);
1016 LDBG() << "Applied permutation: " << llvm::interleaved_array(perms);
1017
1018 auto newLoop = interchangeLoops(builder, secondPloop, perms);
1019 firstToSecondPloopIndices.clear();
1020 firstToSecondPloopIndices.map(block1->getArguments(),
1021 newLoop->getBody()->getArguments());
1022 if (!isFusionLegal(firstPloop, *newLoop, firstToSecondPloopIndices,
1023 mayAlias, builder)) {
1024 LDBG() << "Rejected: " << newLoop;
1025
1026 newLoop->erase();
1027 continue;
1028 }
1029
1030 secondPloop.replaceAllUsesWith(newLoop->getResults());
1031 secondPloop->erase();
1032 secondPloop = *newLoop;
1033 applyLoopFusion(firstPloop, secondPloop, builder);
1034 break;
1035 }
1036}
1037
1039 Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
1040 OpBuilder b(region);
1041 // Consider every single block and attempt to fuse adjacent loops.
1043 for (auto &block : region) {
1044 ploopChains.clear();
1045 ploopChains.push_back({});
1046
1047 // Not using `walk()` to traverse only top-level parallel loops and also
1048 // make sure that there are no side-effecting ops between the parallel
1049 // loops.
1050 bool noSideEffects = true;
1051 for (auto &op : block) {
1052 if (auto ploop = dyn_cast<ParallelOp>(op)) {
1053 if (noSideEffects) {
1054 ploopChains.back().push_back(ploop);
1055 } else {
1056 ploopChains.push_back({ploop});
1057 noSideEffects = true;
1058 }
1059 continue;
1060 }
1061 // TODO: Handle region side effects properly.
1062 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
1063 }
1064 for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
1065 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
1066 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
1067 }
1068 }
1069}
1070
1071namespace {
1072struct ParallelLoopFusion
1073 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
1074 void runOnOperation() override {
1075 auto &aa = getAnalysis<AliasAnalysis>();
1076
1077 auto mayAlias = [&](Value val1, Value val2) -> bool {
1078 // If the memref is defined in one of the parallel loops body, careful
1079 // alias analysis is needed.
1080 // TODO: check if this is still needed as a separate check.
1081 auto val1Def = val1.getDefiningOp();
1082 auto val2Def = val2.getDefiningOp();
1083 auto val1Loop =
1084 val1Def ? val1Def->getParentOfType<ParallelOp>() : nullptr;
1085 auto val2Loop =
1086 val2Def ? val2Def->getParentOfType<ParallelOp>() : nullptr;
1087 if (val1Loop != val2Loop)
1088 return true;
1089
1090 return !aa.alias(val1, val2).isNo();
1091 };
1092
1093 getOperation()->walk([&](Operation *child) {
1094 for (Region &region : child->getRegions())
1096 });
1097 }
1098};
1099} // namespace
1100
1101std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
1102 return std::make_unique<ParallelLoopFusion>();
1103}
return success()
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
static bool canResolveAlias(Operation *loadOp, Operation *storeOp, const IRMapping &loopsIVsMap)
To be called when mayAlias(val1, val2) is true.
static std::optional< ParallelOp > interchangeLoops(OpBuilder &builder, ParallelOp &loop, const ArrayRef< int64_t > &indices)
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase, ValueRange writeIndices, VectorType vecTy, ArrayRef< int64_t > vectorDimForWriteDim, const IRMapping &ivsMap)
Recognize scalar memref.load of an element produced by a vector write (vector.transfer_write or vecto...
static bool loadMatchesVectorWrite(memref::LoadOp loadOp, vector::TransferWriteOp writeOp, const IRMapping &ivsMap)
Recognize scalar memref.load of an element produced by a vector.transfer_write.
static std::optional< int64_t > getAddConstant(Value expr, Value base, const IRMapping &loopsIVsMap)
If the expr value is the result of an integer addition of base and a constant, return the constant.
static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2, const IRMapping &loopsIVsMap, OpBuilder &b)
Check if both memory read/write operations access the same indices (considering also the mapping of i...
static Value getStoreOpTargetBuffer(Operation *op)
static void applyLoopFusion(ParallelOp &firstPloop, ParallelOp &secondPloop, OpBuilder &builder)
Prepend operations of firstPloop's body into secondPloop's body.
static bool haveNoDataDependenciesExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias, OpBuilder &b)
Check that the parallel loops have no mixed access to the same buffers.
static Value getBaseMemref(Operation *op)
Return the base memref value used by the given memory op.
static bool loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp, const IRMapping &firstToSecondPloopIVsMap, OpBuilder &b)
Check if the loadOp reads from the same memory location (same buffer, same indices and same propertie...
static SmallVector< SmallVector< int64_t > > computeCandidateInterchangePermutations(ParallelOp &firstPloop, ParallelOp &secondPloop, int permBudget=120)
static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset, Value writeIndex, int64_t extent, const IRMapping &loopsIVsMap)
static bool opsWriteSameMemLocation(Operation *op1, Operation *op2)
Check if both operations are the same type of memory write op and write to the same memory location (...
static bool noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias, OpBuilder &b)
Check that in each loop there are no read ops on the buffers written by the other loop,...
static bool valsAreEquivalent(Value val1, Value val2, const IRMapping &loopsIVsMap)
Check if val1 (from the first parallel loop) and val2 (from the second) are equivalent,...
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias, OpBuilder &b)
Check if fusion of the two parallel loops is legal: i.e.
static bool opsAccessSameIndicesViaRankReducingSubview(OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap, OpBuilder &b)
Check if both operations access the same positions of the same buffer, but one of the two does it thr...
static bool loadMatchesVectorStore(memref::LoadOp loadOp, vector::StoreOp storeOp, const IRMapping &ivsMap)
Recognize scalar memref.load of an element produced by a vector.store.
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref< bool(Value, Value)> mayAlias)
Check fusion pre-conditions and call fusion if it is possible.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator begin()
Definition Block.h:153
A class for computing basic dominance information.
Definition Dominance.h:143
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void clear()
Clears all mappings held by the mapper.
Definition IRMapping.h:79
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:242
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
This trait indicates that the memory effects of an operation includes the effects of operations neste...
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b)
Checks if two (memref) values are the same or statically known to alias the same region of memory.
void naivelyFuseParallelOps(Region &region, llvm::function_ref< bool(Value, Value)> mayAlias)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf....
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SmallVector< std::tuple< int64_t, int64_t, int64_t > > getConstLoopBounds(mlir::LoopLikeOpInterface loopOp)
Get constant loop bounds and steps for each of the induction variables of the given loop operation,...
Definition Utils.cpp:1609
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
Definition MemRefUtils.h:26
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
bool operator==(LoopIV const &other) const
bool operator!=(LoopIV const &other) const
static bool isEqual(const LoopIV &lhs, const LoopIV &rhs)
static unsigned getHashValue(const LoopIV &val)
The following effect indicates that the operation frees some resource that has been allocated.
The following effect indicates that the operation reads from some resource.
The following effect indicates that the operation writes to some resource.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.