MLIR 18.0.0git
SubsetHoisting.cpp
//===- SubsetHoisting.cpp - Linalg hoisting transformations --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant subset
// operations in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//
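
// Illustrative sketch (not part of the original file): hoisting a matching
// transfer_read / transfer_write pair whose indices are defined above the
// loop turns
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
//       -> (tensor<8xf32>) {
//     %v = vector.transfer_read %t[%c0], %pad : tensor<8xf32>, vector<4xf32>
//     %u = "compute"(%v) : (vector<4xf32>) -> vector<4xf32>
//     %w = vector.transfer_write %u, %t[%c0] : vector<4xf32>, tensor<8xf32>
//     scf.yield %w : tensor<8xf32>
//   }
//
// into a loop that carries the vector instead of repeatedly reading and
// writing the same tensor subset:
//
//   %v0 = vector.transfer_read %init[%c0], %pad : tensor<8xf32>, vector<4xf32>
//   %f:2 = scf.for %i = %lb to %ub step %s iter_args(%t = %init, %v = %v0)
//       -> (tensor<8xf32>, vector<4xf32>) {
//     %u = "compute"(%v) : (vector<4xf32>) -> vector<4xf32>
//     scf.yield %t, %u : tensor<8xf32>, vector<4xf32>
//   }
//   %r = vector.transfer_write %f#1, %f#0[%c0] : vector<4xf32>, tensor<8xf32>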

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "subset-hoisting"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

/// Return true if the location of the subset defined by the op is invariant of
/// the loop iteration.
static bool
isSubsetLocationLoopInvariant(scf::ForOp forOp,
                              vector::TransferWriteOp transferWriteOp) {
  for (Value operand : transferWriteOp.getIndices())
    if (!forOp.isDefinedOutsideOfLoop(operand))
      return false;
  return true;
}

/// Return true if the location of the subset defined by the op is invariant of
/// the loop iteration.
static bool isSubsetLocationLoopInvariant(scf::ForOp forOp,
                                          tensor::InsertSliceOp insertSliceOp) {
  for (Value operand : insertSliceOp->getOperands().drop_front(
           tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
    if (!forOp.isDefinedOutsideOfLoop(operand))
      return false;
  return true;
}
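
// For instance (an illustrative sketch, names hypothetical): with %c0 and %j
// defined above the scf.for, the subset location of
//
//   %w = tensor.insert_slice %v into %t[%c0, %j] [1, 8] [1, 1]
//       : tensor<1x8xf32> into tensor<4x8xf32>
//
// is loop-invariant and passes the check above; replacing %c0 or %j with the
// induction variable makes the location vary per iteration and the check
// fails.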

/// Given a `srcTensor` that is a block argument belonging to a loop, greedily
/// look for the first read that can be hoisted out of the loop (i.e. that
/// satisfies the following conditions):
/// - The read is of type `tensor.extract_slice`.
/// - The read is one of the uses of `srcTensor`.
/// - The read is of the same subset that `tensor.insert_slice` writes.
// TODO: Unify implementations once the "bypassing behavior" is the same.
static FailureOr<tensor::ExtractSliceOp>
findHoistableMatchingExtractSlice(RewriterBase &rewriter,
                                  tensor::InsertSliceOp insertSliceOp,
                                  BlockArgument srcTensor) {
  assert(isa<RankedTensorType>(srcTensor.getType()) && "not a ranked tensor");

  auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());

  LLVM_DEBUG(DBGS() << "--find matching read for: " << insertSliceOp << "\n";
             DBGS() << "--amongst users of: " << srcTensor << "\n");

  SmallVector<Operation *> users(srcTensor.getUsers());
  if (forOp.isDefinedOutsideOfLoop(insertSliceOp.getDest()))
    llvm::append_range(users, insertSliceOp.getDest().getUsers());

  for (Operation *user : users) {
    LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
    auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    // Skip ops other than extract_slice with an exact match of their tensor
    // subset.
    if (extractSliceOp) {
      auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
      if (extractSliceOp.getResultType() != insertSliceOp.getSourceType() ||
          !extractSliceOp.isSameAs(insertSliceOp, isSame)) {
        LLVM_DEBUG(DBGS() << "------not a matching extract_slice\n";
                   DBGS() << *user << " vs " << *insertSliceOp << "\n");
        continue;
      }

      // Skip an extract_slice whose source is defined within the loop: we
      // would need to hoist that definition first, otherwise dominance
      // violations trigger.
      if (!isa<BlockArgument>(extractSliceOp.getSource()) &&
          !forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
        LLVM_DEBUG(DBGS() << "------extract_slice source is loop-dependent\n");
        continue;
      }
      return extractSliceOp;
    }

    // TODO: Look through disjoint subsets, similar to vector.transfer_write,
    // and unify implementations.
  }

  LLVM_DEBUG(DBGS() << "----no matching extract_slice\n");
  return failure();
}
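
// Matching-read sketch (illustrative, not from the original file): the
// extract_slice below reads exactly the subset the insert_slice writes (same
// offsets/sizes/strides, and a result type equal to the inserted source
// type), so it would be returned as the hoistable read:
//
//   %e = tensor.extract_slice %t[%c0, %j] [1, 8] [1, 1]
//       : tensor<4x8xf32> to tensor<1x8xf32>
//   ...
//   %w = tensor.insert_slice %u into %t[%c0, %j] [1, 8] [1, 1]
//       : tensor<1x8xf32> into tensor<4x8xf32>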

/// Given a `srcTensor` that is a block argument belonging to a loop, greedily
/// look for the first read that can be hoisted out of the loop (i.e. that
/// satisfies the following conditions):
/// - The read is of type `vector.transfer_read`.
/// - The read is one of the uses of `srcTensor`.
/// - The read is of the same subset that `vector.transfer_write` writes.
// TODO: Unify implementations once the "bypassing behavior" is the same.
static FailureOr<vector::TransferReadOp>
findHoistableMatchingTransferRead(RewriterBase &rewriter,
                                  vector::TransferWriteOp transferWriteOp,
                                  BlockArgument srcTensor) {
  if (!isa<RankedTensorType>(srcTensor.getType()))
    return failure();

  auto forOp = cast<scf::ForOp>(srcTensor.getOwner()->getParentOp());

  LLVM_DEBUG(DBGS() << "--find matching read for: " << transferWriteOp << "\n";
             DBGS() << "--amongst users of: " << srcTensor << "\n");

  // vector.transfer_write is a bit peculiar: we look through dependencies
  // to disjoint tensor subsets. This requires a while loop.
  // TODO: Look through disjoint subsets for tensor.insert_slice and unify
  // implementations.
  SmallVector<Operation *> users(srcTensor.getUsers());
  // TODO: transferWriteOp.getSource is actually the destination tensor!!
  if (forOp.isDefinedOutsideOfLoop(transferWriteOp.getSource()))
    llvm::append_range(users, transferWriteOp.getSource().getUsers());
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    LLVM_DEBUG(DBGS() << "----inspect user: " << *user << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(user);
    if (read) {
      // Skip ops other than transfer_read with an exactly matching subset.
      if (read.getIndices() != transferWriteOp.getIndices() ||
          read.getVectorType() != transferWriteOp.getVectorType()) {
        LLVM_DEBUG(DBGS() << "------not a transfer_read that matches the "
                             "transfer_write: "
                          << *user << "\n\t(vs " << *transferWriteOp << ")\n");
        continue;
      }

      // The transfer_read may source a tensor that is defined within the
      // loop: we traverse it by virtue of bypassing disjoint subset
      // operations rooted at a bbArg and yielding a matching value.
      if (!isa<BlockArgument>(read.getSource()) &&
          !forOp.isDefinedOutsideOfLoop(read.getSource())) {
        LLVM_DEBUG(DBGS() << "------transfer_read source appears loop "
                             "dependent but will be tested for disjointness "
                             "as part of the bypass analysis\n");
      }
      LLVM_DEBUG(DBGS() << "------found match\n");
      return read;
    }

    // As an optimization, we look further through dependencies to disjoint
    // tensor subsets. This creates more opportunities to find a matching read.
    if (isa<vector::TransferWriteOp>(user)) {
      // If we find a write with disjoint indices, append all its uses.
      // TODO: Generalize areSubsetsDisjoint and allow bypasses other than
      // vector.transfer_write - vector.transfer_write.
      if (vector::isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(user),
              cast<VectorTransferOpInterface>(
                  transferWriteOp.getOperation()))) {
        LLVM_DEBUG(DBGS() << "----follow through disjoint write\n");
        users.append(user->getUsers().begin(), user->getUsers().end());
      } else {
        LLVM_DEBUG(DBGS() << "----skip non-disjoint write\n");
      }
    }
  }

  LLVM_DEBUG(DBGS() << "--no matching transfer_read\n");
  return rewriter.notifyMatchFailure(transferWriteOp,
                                     "no matching transfer_read");
}
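
// Bypass sketch (illustrative): while searching for the read that matches a
// write at [%c0] of vector<4xf32>, an intervening transfer_write at [%c8]
// covers a disjoint range per vector::isDisjointTransferIndices, so the
// search continues through the users of its result and still finds the read:
//
//   %w0 = vector.transfer_write %a, %t[%c8] : vector<4xf32>, tensor<16xf32>
//   %v  = vector.transfer_read %w0[%c0], %pad : tensor<16xf32>, vector<4xf32>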

/// Return the `vector.transfer_write` that produces `yieldOperand`, if:
/// - The write operates on tensors.
/// - All indices are defined outside of the loop.
/// Return failure otherwise.
///
/// This is a sufficient condition to hoist the `vector.transfer_write`; other
/// operands can always be yielded by the loop where needed.
// TODO: Generalize beyond scf::ForOp.
// TODO: Unify implementations once the "bypassing behavior" is the same.
static FailureOr<vector::TransferWriteOp>
getLoopInvariantTransferWriteDefining(RewriterBase &rewriter, scf::ForOp forOp,
                                      BlockArgument bbArg,
                                      OpOperand &yieldOperand) {
  assert(bbArg.getArgNumber() ==
             forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
         "bbArg and yieldOperand must match");
  assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");

  Value v = yieldOperand.get();
  auto transferWriteOp = v.getDefiningOp<vector::TransferWriteOp>();
  if (!transferWriteOp)
    return rewriter.notifyMatchFailure(v.getLoc(), "not a transfer_write");

  if (transferWriteOp->getNumResults() == 0) {
    return rewriter.notifyMatchFailure(v.getLoc(),
                                       "unsupported transfer_write on buffers");
  }

  // We do not explicitly check that the destination is a bbArg that matches
  // the yield operand, as this would prevent us from bypassing other
  // non-conflicting writes.

  // Indexing must not depend on `forOp`.
  if (!isSubsetLocationLoopInvariant(forOp, transferWriteOp))
    return rewriter.notifyMatchFailure(
        v.getLoc(), "transfer_write indexing is loop-dependent");

  return transferWriteOp;
}

/// Return the `tensor.insert_slice` that produces `yieldOperand`, if:
/// 1. Its destination tensor is a block argument of the `forOp`.
/// 2. The unique use of its result is a yield with operand number matching
///    the block argument.
/// 3. All indices are defined outside of the loop.
/// Return failure otherwise.
///
/// This is a sufficient condition to hoist the `tensor.insert_slice`; other
/// operands can always be yielded by the loop where needed.
/// Note: 1. + 2. ensure that the yield / iter_args cycle results in proper
/// semantics (i.e. no ping-pong between iter_args across iterations).
// TODO: Generalize beyond scf::ForOp.
// TODO: Unify implementations once the "bypassing behavior" is the same.
static FailureOr<tensor::InsertSliceOp>
getLoopInvariantInsertSliceDefining(RewriterBase &rewriter, scf::ForOp forOp,
                                    BlockArgument bbArg,
                                    OpOperand &yieldOperand) {
  assert(bbArg.getArgNumber() ==
             forOp.getNumInductionVars() + yieldOperand.getOperandNumber() &&
         "bbArg and yieldOperand must match");
  assert(isa<scf::YieldOp>(yieldOperand.getOwner()) && "must be an scf.yield");

  Value v = yieldOperand.get();
  auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>();
  if (!insertSliceOp)
    return rewriter.notifyMatchFailure(v.getLoc(), "not an insert_slice");

  // The tensor inserted into must be a bbArg at the position matching the
  // yield operand.
  // TODO: In the future we should not perform this check if we want to bypass
  // other non-conflicting writes.
  if (bbArg != insertSliceOp.getDest())
    return rewriter.notifyMatchFailure(v.getLoc(), "not a matching bbArg");

  // The indexing inserted into must not depend on `forOp`.
  if (!isSubsetLocationLoopInvariant(forOp, insertSliceOp))
    return rewriter.notifyMatchFailure(
        v.getLoc(), "insert_slice indexing is loop-dependent");

  return insertSliceOp;
}
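
// The accepted shape, schematically (an illustrative sketch): the
// insert_slice destination is the iter_arg %t itself, and the unique use of
// its result is the scf.yield operand at the matching position:
//
//   scf.for %i = %lb to %ub step %s iter_args(%t = %init)
//       -> (tensor<4x8xf32>) {
//     ...
//     %w = tensor.insert_slice %u into %t[%c0, %j] [1, 8] [1, 1]
//         : tensor<1x8xf32> into tensor<4x8xf32>
//     scf.yield %w : tensor<4x8xf32>
//   }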

/// Check whether the chunk of data inserted by `writeOp` is read by any op
/// other than `candidateReadOp`. Such a conflicting operation prevents
/// hoisting; return it, or nullptr if none is found.
// TODO: Generalize subset disjunction analysis/interface.
// TODO: Support more subset op types.
static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
                                                   Operation *candidateReadOp,
                                                   BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  SmallVector<Value::use_range> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateReadOp || user == writeOp)
        continue;

      // TODO: Consider all transitive uses through
      // extract_slice/insert_slice. Atm we just bail because a stronger
      // analysis is needed for these cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return user;

      // Consider all transitive uses through a vector.transfer_write.
      if (isa<vector::TransferWriteOp>(writeOp)) {
        if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
          uses.push_back(writeUser->getResult(0).getUses());
          continue;
        }
      }

      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left over from a previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        Value arg = forUser.getBody()->getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }

      // Follow a yield use, but only if it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser &&
          writeOp->getParentOp()->isAncestor(yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }

      // If the write is a vector::TransferWriteOp, it may have been bypassed
      // and we need to check subset disjunction.
      if (isa<vector::TransferWriteOp>(writeOp)) {
        auto read = dyn_cast<vector::TransferReadOp>(user);
        if (!read || !vector::isDisjointTransferIndices(
                         cast<VectorTransferOpInterface>(read.getOperation()),
                         cast<VectorTransferOpInterface>(writeOp))) {
          return user;
        }
      }
    }
  }
  return nullptr;
}
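
// Conflict sketch (illustrative): besides the candidate read at [%c0], a
// second transfer_read at [%c2] overlaps the 4-element chunk written at
// [%c0], so it is returned as a conflicting user and hoisting is rejected:
//
//   %v0 = vector.transfer_read %t[%c0], %pad : tensor<16xf32>, vector<4xf32>
//   %v1 = vector.transfer_read %t[%c2], %pad : tensor<16xf32>, vector<4xf32>
//   %w  = vector.transfer_write %u, %t[%c0] : vector<4xf32>, tensor<16xf32>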

/// Mechanical hoisting of a matching read / write pair.
/// Return the newly created scf::ForOp with the extra yields.
// TODO: Unify implementations once the "bypassing behavior" is the same.
static scf::ForOp hoistTransferReadWrite(
    RewriterBase &rewriter, vector::TransferReadOp transferReadOp,
    vector::TransferWriteOp transferWriteOp, BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  LLVM_DEBUG(DBGS() << "--Start hoisting\n";
             DBGS() << "--Hoist read : " << transferReadOp << "\n";
             DBGS() << "--Hoist write: " << transferWriteOp << "\n";
             DBGS() << "--Involving  : " << tensorBBArg << "\n");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // 1. Hoist the read op. Thanks to our previous checks we know this will not
  // trigger dominance violations once the bbArgs are updated.
  // TODO: should the rewriter ever want to track this move?
  transferReadOp->moveBefore(forOp);
  if (!forOp.isDefinedOutsideOfLoop(transferReadOp.getSource())) {
    rewriter.startRootUpdate(transferReadOp);
    transferReadOp.getSourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
    rewriter.finalizeRootUpdate(transferReadOp);
  }

  // 2. Rewrite `loop` with an additional yield. This is the quantity that is
  // computed iteratively but whose storage has become loop-invariant.
  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
                                 ArrayRef<BlockArgument> newBBArgs) {
    return SmallVector<Value>{transferWriteOp.getVector()};
  };
  auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, {transferReadOp.getVector()},
      /*replaceInitOperandUsesInLoop=*/true, yieldFn));

  // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
  auto yieldOp =
      cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
  // TODO: transferWriteOp.getSource is actually the destination tensor!!
  rewriter.startRootUpdate(yieldOp);
  yieldOp->setOperand(initArgNumber, transferWriteOp.getSource());
  rewriter.finalizeRootUpdate(yieldOp);

  // 4. Hoist the write after the loop and make uses of
  // newForOp.getResult(initArgNumber) flow through it.
  // TODO: should the rewriter ever want to track this move?
  transferWriteOp->moveAfter(newForOp);
  rewriter.startRootUpdate(transferWriteOp);
  transferWriteOp.getVectorMutable().assign(newForOp.getResults().back());
  // TODO: transferWriteOp.getSource is actually the destination tensor!!
  transferWriteOp.getSourceMutable().assign(newForOp.getResult(initArgNumber));
  rewriter.finalizeRootUpdate(transferWriteOp);
  rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
                                transferWriteOp.getResult(), transferWriteOp);
  return newForOp;
}

/// Mechanical hoisting of a matching read / write pair.
/// Return the newly created scf::ForOp with the extra yields.
// TODO: Unify implementations once the "bypassing behavior" is the same.
static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter,
                                          tensor::ExtractSliceOp extractSliceOp,
                                          tensor::InsertSliceOp insertSliceOp,
                                          BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  LLVM_DEBUG(DBGS() << "--Start hoisting\n";
             DBGS() << "--Hoist read : " << extractSliceOp << "\n";
             DBGS() << "--Hoist write: " << insertSliceOp << "\n";
             DBGS() << "--Involving  : " << tensorBBArg << "\n");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  int64_t initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // 1. Hoist the read op. Thanks to our previous checks we know this will not
  // trigger dominance violations once the bbArgs are updated.
  // TODO: should the rewriter ever want to track this move?
  extractSliceOp->moveBefore(forOp);
  if (!forOp.isDefinedOutsideOfLoop(extractSliceOp.getSource())) {
    assert(extractSliceOp.getSource() == tensorBBArg &&
           "extractSlice source not defined above must be the tracked bbArg");
    rewriter.startRootUpdate(extractSliceOp);
    extractSliceOp.getSourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
    rewriter.finalizeRootUpdate(extractSliceOp);
  }

  // 2. Rewrite `loop` with an additional yield. This is the quantity that is
  // computed iteratively but whose storage has become loop-invariant.
  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
                                 ArrayRef<BlockArgument> newBBArgs) {
    return SmallVector<Value>{insertSliceOp.getSource()};
  };
  auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, extractSliceOp.getResult(),
      /*replaceInitOperandUsesInLoop=*/true, yieldFn));

  // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
  auto yieldOp =
      cast<scf::YieldOp>(newForOp.getRegion().front().getTerminator());
  // TODO: should the rewriter ever want to track this?
  rewriter.startRootUpdate(yieldOp);
  yieldOp->setOperand(initArgNumber, insertSliceOp.getDest());
  rewriter.finalizeRootUpdate(yieldOp);

  // 4. Hoist the write after the loop and make uses of
  // newForOp.getResult(initArgNumber) flow through it.
  // TODO: should the rewriter ever want to track this move?
  insertSliceOp->moveAfter(newForOp);
  rewriter.startRootUpdate(insertSliceOp);
  insertSliceOp.getSourceMutable().assign(newForOp.getResults().back());
  insertSliceOp.getDestMutable().assign(newForOp.getResult(initArgNumber));
  rewriter.finalizeRootUpdate(insertSliceOp);
  rewriter.replaceAllUsesExcept(newForOp.getResult(initArgNumber),
                                insertSliceOp.getResult(), insertSliceOp);
  return newForOp;
}
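
// Slice-variant sketch (illustrative): a matching pair
//
//   %e = tensor.extract_slice %t[%c0] [8] [1] : tensor<16xf32> to tensor<8xf32>
//   ...
//   %w = tensor.insert_slice %u into %t[%c0] [8] [1]
//       : tensor<8xf32> into tensor<16xf32>
//
// is split around the loop: the extract_slice moves before the scf.for and
// reads from the init operand, the loop yields and carries the small tensor,
// and the insert_slice moves after the loop, consuming the new last result as
// its source and the original result as its destination.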

/// Greedily hoist redundant subset extract/insert operations on tensors
/// outside of `forOp`.
/// Return the unmodified `forOp` if no hoisting occurred.
/// Return a new scf::ForOp if hoisting on tensors occurred.
scf::ForOp
mlir::linalg::hoistRedundantSubsetExtractInsert(RewriterBase &rewriter,
                                                scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Enter hoistRedundantSubsetExtractInsert scf.for\n");

  LLVM_DEBUG(DBGS() << "\n"; DBGS() << "Consider " << forOp << "\n");

  scf::ForOp newForOp = forOp;
  do {
    forOp = newForOp;
    // Re-fetch the terminator: a previous iteration may have replaced the
    // loop (and its body) with a new one.
    Operation *yield = forOp.getBody()->getTerminator();
    for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
      LLVM_DEBUG(DBGS() << "Consider " << it.value() << "\n");

      // 1. Find a loop-invariant subset write yielding `ret` that we can
      // consider for hoisting.
      // TODO: TypeSwitch when we add more cases.
      OpOperand &ret = yield->getOpOperand(it.index());
      FailureOr<vector::TransferWriteOp> transferWriteOp =
          getLoopInvariantTransferWriteDefining(rewriter, forOp, it.value(),
                                                ret);
      FailureOr<tensor::InsertSliceOp> insertSliceOp =
          getLoopInvariantInsertSliceDefining(rewriter, forOp, it.value(), ret);
      if (failed(transferWriteOp) && failed(insertSliceOp)) {
        LLVM_DEBUG(DBGS() << "no loop-invariant write defining iter_args "
                          << it.value() << "\n");
        continue;
      }

      Operation *writeOp = succeeded(transferWriteOp)
                               ? transferWriteOp->getOperation()
                               : insertSliceOp->getOperation();

      // 2. Only accept writes with a single use (i.e. the yield).
      if (!writeOp->hasOneUse()) {
        LLVM_DEBUG(DBGS() << "write with more than 1 use " << *writeOp << "\n");
        continue;
      }

      LLVM_DEBUG(DBGS() << "Write to hoist: " << *writeOp << "\n");

      // 3. Find a matching read that can also be hoisted.
      Operation *matchingReadOp = nullptr;
      // TODO: TypeSwitch.
      if (succeeded(transferWriteOp)) {
        auto maybeTransferRead = findHoistableMatchingTransferRead(
            rewriter, *transferWriteOp, it.value());
        if (succeeded(maybeTransferRead))
          matchingReadOp = maybeTransferRead->getOperation();
      } else if (succeeded(insertSliceOp)) {
        auto maybeExtractSlice = findHoistableMatchingExtractSlice(
            rewriter, *insertSliceOp, it.value());
        if (succeeded(maybeExtractSlice))
          matchingReadOp = maybeExtractSlice->getOperation();
      } else {
        llvm_unreachable("unexpected case");
      }
      if (!matchingReadOp) {
        LLVM_DEBUG(DBGS() << "No matching read\n");
        continue;
      }

      // 4. Make sure no other use reads the modified part of the tensor.
      // This is necessary to guard against hazards when non-conflicting
      // subset ops are bypassed.
      Operation *maybeUnknownOp =
          isTensorChunkAccessedByUnknownOp(writeOp, matchingReadOp, it.value());
      if (maybeUnknownOp) {
        LLVM_DEBUG(DBGS() << "Tensor chunk accessed by unknown op, skip: "
                          << *maybeUnknownOp << "\n");
        continue;
      }

      // 5. Perform the actual mechanical hoisting.
      // TODO: TypeSwitch.
      LLVM_DEBUG(DBGS() << "Read to hoist: " << *matchingReadOp << "\n");
      if (succeeded(transferWriteOp)) {
        newForOp = hoistTransferReadWrite(
            rewriter, cast<vector::TransferReadOp>(matchingReadOp),
            *transferWriteOp, it.value());
      } else if (succeeded(insertSliceOp)) {
        newForOp = hoistExtractInsertSlice(
            rewriter, cast<tensor::ExtractSliceOp>(matchingReadOp),
            *insertSliceOp, it.value());
      } else {
        llvm_unreachable("unexpected case");
      }
      break;
    }
  } while (forOp != newForOp);

  return newForOp;
}
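
// Typical invocation, as a hypothetical sketch (the surrounding setup is
// illustrative, not part of this file): the returned loop must be used in
// place of `forOp`, which may have been replaced by the rewrite.
//
//   IRRewriter rewriter(forOp->getContext());
//   scf::ForOp newForOp =
//       linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp);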