MLIR  19.0.0git
OneShotAnalysis.cpp
Go to the documentation of this file.
1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. By default, function boundaries
10 // (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
12 // simple call graphs without loops.
13 //
14 // One-Shot Bufferize consists of three phases.
15 //
16 // 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
17 // without inserting buffer copies. The analysis queries op bufferization
18 // semantics via `BufferizableOpInterface`.
19 // 2. Insert copies for OpOperands that were decided to bufferize out-of-place
20 // in tensor land during `TensorCopyInsertion`.
21 // 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
22 //
23 // This file contains only the analysis. For convenience, this file also
24 // contains a helper function `runOneShotBufferize` that analyzes an op (and its
25 // nested ops) and then bufferizes it.
26 //
27 // Inplace bufferization decisions are passed from the analysis to the
28 // `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
29 // debugging purposes with `testAnalysisOnly`.
30 //
31 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
32 // treated conservatively. E.g., the analysis has to assume that their tensor
33 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
34 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
35 // inserted at the bufferization boundary.
36 //
37 // This analysis caters to high-performance codegen where buffer reuse is deemed
38 // critical: the analysis should fail if the bufferized form of the function
39 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
40 
42 
43 #include <optional>
44 #include <random>
45 
52 #include "mlir/IR/AsmState.h"
53 #include "mlir/IR/Dominance.h"
54 #include "mlir/IR/Iterators.h"
55 #include "mlir/IR/Operation.h"
56 #include "mlir/IR/TypeUtilities.h"
59 #include "llvm/ADT/DenseSet.h"
60 #include "llvm/ADT/SetVector.h"
61 
63 
64 // Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
65 // output.
66 #define DEBUG_TYPE "one-shot-analysis"
67 
68 using namespace mlir;
69 using namespace mlir::bufferization;
70 
71 static bool isaTensor(Type t) { return isa<TensorType>(t); }
72 
73 //===----------------------------------------------------------------------===//
74 // Bufferization-specific attribute manipulation.
75 // These are for testing and debugging only. Bufferization information is stored
76 // in OneShotBufferizationState. When run with `testAnalysisOnly`, the IR is
77 // annotated with the results of the analysis, so that they can be checked in
78 // tests.
79 //===----------------------------------------------------------------------===//
80 
81 /// Attribute marker to specify op operands that bufferize in-place.
82 constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";
83 
/// Attribute marker used to annotate each op result with its alias set
/// (testing/debugging only, per the section comment above).
84 constexpr StringLiteral kOpResultAliasSetAttrName =
85  "__opresult_alias_set_attr__";
86 
/// Attribute marker used to annotate block arguments with their alias set
/// (testing/debugging only, per the section comment above).
87 constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";
88 
89 /// Mark whether OpOperand will be bufferized inplace.
// Annotates the operand's owner op with an `__inplace_operands_attr__` string
// array: one entry per operand — "none" for non-tensor operands, "true"/"false"
// for tensor operands. Testing/debugging helper only.
90 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
 91  Operation *op = opOperand.getOwner();
 92  SmallVector<StringRef> inPlaceVector;
// Reuse the existing annotation if the op was already marked once; otherwise
// build a fresh vector with a default entry per operand.
 93  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
 94  inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
 95  cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
 96  } else {
 97  inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
 98  for (OpOperand &opOperand : op->getOpOperands())
 99  if (isa<TensorType>(opOperand.get().getType()))
 100  inPlaceVector[opOperand.getOperandNumber()] = "false";
 101  }
 102  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
// NOTE(review): orig line 103 (presumably `op->setAttr(kInPlaceOperandsAttrName,`)
// was dropped during extraction — the next line is an orphaned call
// continuation; confirm against upstream before compiling.
 104  OpBuilder(op).getStrArrayAttr(inPlaceVector));
 105 }
106 
107 //===----------------------------------------------------------------------===//
108 // OneShotAnalysisState
109 //===----------------------------------------------------------------------===//
110 
// NOTE(review): the enclosing constructor signature (orig lines 111-113,
// presumably `OneShotAnalysisState::OneShotAnalysisState(...)`) was dropped
// during extraction. Body: seeds alias/equivalence sets for all tensor SSA
// values in `op`, then eagerly marks operands that must bufferize in-place.
 114  // Set up alias sets.
 115  op->walk([&](Operation *op) {
 116  for (Value v : op->getResults())
 117  if (isa<TensorType>(v.getType()))
// NOTE(review): orig line 118 (presumably `createAliasInfoEntry(v);`) was
// dropped during extraction — this `if` currently has no visible body.
 119  for (Region &r : op->getRegions())
 120  for (Block &b : r.getBlocks())
 121  for (auto bbArg : b.getArguments())
 122  if (isa<TensorType>(bbArg.getType()))
 123  createAliasInfoEntry(bbArg);
 124  });
 125 
 126  // Mark OpOperands in-place that must bufferize in-place.
 127  op->walk([&](BufferizableOpInterface bufferizableOp) {
// Ops filtered out by the options are skipped entirely (including nested ops).
 128  if (!options.isOpAllowed(bufferizableOp))
 129  return WalkResult::skip();
 130  for (OpOperand &opOperand : bufferizableOp->getOpOperands())
 131  if (isa<TensorType>(opOperand.get().getType()))
 132  if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
 133  bufferizeInPlace(opOperand);
 134  return WalkResult::advance();
 135  });
 136 }
137 
// NOTE(review): signature start (orig line 138) dropped during extraction;
// member-access pattern suggests OneShotAnalysisState::applyOnEquivalenceClass.
// Invokes `fun` on every value in the equivalence class of `v`.
 139  Value v, function_ref<void(Value)> fun) const {
 140  auto leaderIt = equivalentInfo.findLeader(v);
 141  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
 142  ++mit) {
 143  fun(*mit);
 144  }
 145 }
146 
// NOTE(review): signature start (orig line 147) dropped during extraction;
// member-access pattern suggests OneShotAnalysisState::applyOnAliases.
// Invokes `fun` on every value in the alias set of the given value.
 148  function_ref<void(Value)> fun) const {
 149  auto leaderIt = aliasInfo.findLeader(v);
 150  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
 151  fun(*mit);
 152  }
 153 }
154 
// NOTE(review): signature start (orig line 155) dropped during extraction;
// body queries `equivalentInfo`, i.e. buffer equivalence of v1 and v2.
 156  Value v2) const {
 157  return equivalentInfo.isEquivalent(v1, v2);
 158 }
159 
// NOTE(review): signature start (orig line 160) dropped during extraction;
// body queries `aliasInfo`, i.e. whether v1 and v2 may alias after bufferization.
 161  Value v2) const {
 162  return aliasInfo.isEquivalent(v1, v2);
 163 }
164 
// NOTE(review): signature start (orig line 165) dropped during extraction;
// this is the in-place decision for one OpOperand: record it, merge the
// operand's value into the alias sets of all its aliasing values, and bump the
// in-place statistic. Idempotent (early return if already decided).
 166  if (inplaceBufferized.contains(&operand))
 167  return;
 168  inplaceBufferized.insert(&operand);
 169  for (AliasingValue alias : getAliasingValues(operand))
 170  aliasInfo.unionSets(alias.value, operand.get());
 171  ++statNumTensorInPlace;
 172 }
173 
// NOTE(review): signature start (orig line 174) dropped during extraction;
// records an out-of-place decision. Only the statistic is updated — the
// default (absence from `inplaceBufferized`) already means out-of-place.
 175  assert(!inplaceBufferized.contains(&operand) &&
 176  "OpOperand was already decided to bufferize inplace");
 177  ++statNumTensorOutOfPlace;
 178 }
179 
// NOTE(review): signature start (orig line 180) dropped during extraction;
// registers `v` as a singleton in both the alias and equivalence structures.
 181  aliasInfo.insert(v);
 182  equivalentInfo.insert(v);
 183 }
184 
// NOTE(review): signature start (orig line 185) dropped during extraction;
// walks `op` and collects, into `undefinedTensorUses`, every use of a tensor
// OpResult that has no preceding definition (undefined buffer contents).
 186  op->walk([&](Operation *op) {
 187  // Skip unknown ops.
 188  auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
 189  if (!bufferizableOp)
 190  return WalkResult::skip();
 191 
 192  // Check all tensor OpResults.
 193  for (OpResult opResult : op->getOpResults()) {
 194  if (!isa<TensorType>(opResult.getType()))
 195  continue;
 196 
 197  // If there is no preceding definition, the tensor contents are
 198  // undefined.
 199  if (findDefinitionsCached(opResult).empty())
 200  for (OpOperand &use : opResult.getUses())
 201  undefinedTensorUses.insert(&use);
 202  }
 203 
 204  return WalkResult::advance();
 205  });
 206 }
207 
// NOTE(review): signature start (orig line 208) dropped during extraction;
// simple membership query against the set built by the walk above.
 209  return undefinedTensorUses.contains(opOperand);
 210 }
211 
// NOTE(review): signature start (orig line 212) dropped during extraction;
// returns "true" iff this operand was decided to bufferize in-place.
 213  return inplaceBufferized.contains(&opOperand);
 214 }
215 
// NOTE(review): signature start (orig line 216) dropped during extraction;
// returns "true" iff any alias of `value` has a use that is both decided
// in-place and bufferizes to a memory write.
 217  bool isWritten = false;
 218  applyOnAliases(value, [&](Value val) {
 219  for (OpOperand &use : val.getUses())
 220  if (isInPlace(use) && bufferizesToMemoryWrite(use))
 221  isWritten = true;
 222  });
 223  return isWritten;
 224 }
225 
// NOTE(review): signature start (orig line 226) dropped during extraction;
// delegates writability of `value` to the owning op's BufferizableOpInterface.
 227  // TODO: Out-of-place bufferized value could be considered writable.
 228  // Query BufferizableOpInterface to see if the BlockArgument is writable.
 229  if (auto bufferizableOp =
 230  getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
 231  return bufferizableOp.isWritable(value, *this);
 232 
 233  // Not a bufferizable op: The conservative answer is "not writable".
 234  return false;
 235 }
236 
// NOTE(review): signature start (orig line 237) dropped during extraction;
// merges the alias sets of v1 and v2.
 238  aliasInfo.unionSets(v1, v2);
 239 }
240 
// NOTE(review): signature start (orig line 241) dropped during extraction;
// merges the equivalence classes of v1 and v2.
 242  equivalentInfo.unionSets(v1, v2);
 243 }
244 
246 
247 //===----------------------------------------------------------------------===//
248 // Bufferization-specific alias analysis.
249 //===----------------------------------------------------------------------===//
250 
251 /// Return true if opOperand has been decided to bufferize in-place.
252 static bool isInplaceMemoryWrite(OpOperand &opOperand,
253  const OneShotAnalysisState &state) {
254  // OpOperands that do not bufferize to a memory write do not write in-place.
255  if (!state.bufferizesToMemoryWrite(opOperand))
256  return false;
257  // Check current bufferization decisions.
258  return state.isInPlace(opOperand);
259 }
260 
261 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
262 /// properly dominates `b` and `b` is not inside `a`.
263 static bool happensBefore(Operation *a, Operation *b,
264  const DominanceInfo &domInfo) {
265  do {
266  // TODO: Instead of isProperAncestor + properlyDominates, we should use
267  // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
268  if (a->isProperAncestor(b))
269  return false;
270  if (domInfo.properlyDominates(a, b))
271  return true;
272  } while ((a = a->getParentOp()));
273  return false;
274 }
275 
276 static bool isReachable(Block *from, Block *to, ArrayRef<Block *> except) {
277  DenseSet<Block *> visited;
278  SmallVector<Block *> worklist;
279  for (Block *succ : from->getSuccessors())
280  worklist.push_back(succ);
281  while (!worklist.empty()) {
282  Block *next = worklist.pop_back_val();
283  if (llvm::is_contained(except, next))
284  continue;
285  if (next == to)
286  return true;
287  if (visited.contains(next))
288  continue;
289  visited.insert(next);
290  for (Block *succ : next->getSuccessors())
291  worklist.push_back(succ);
292  }
293  return false;
294 }
295 
296 /// Return `true` if op dominance can be used to rule out a read-after-write
297 /// conflicts based on the ordering of ops. Returns `false` if op dominance
298 /// cannot be used to due region-based loops.
299 ///
300 /// Generalized op dominance can often be used to rule out potential conflicts
301 /// due to "read happens before write". E.g., the following IR is not a RaW
302 /// conflict because the read happens *before* the write.
303 ///
304 /// Example 1:
305 /// %0 = ... : tensor<?xf32> // DEF
306 /// "reading_op"(%0) : tensor<?xf32> // READ
307 /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
308 ///
309 /// This is no longer true inside loops (or repetitive regions). In such cases,
310 /// there may not be a meaningful `happensBefore` relationship because ops
311 /// could be executed multiple times. E.g.:
312 ///
313 /// Example 2:
314 /// %0 = ... : tensor<?xf32> // DEF
315 /// scf.for ... {
316 /// "reading_op"(%0) : tensor<?xf32> // READ
317 /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
318 /// ...
319 /// }
320 ///
321 /// In the above example, reading_op happens before writing_op according to
322 /// op dominance. However, both ops may happen multiple times; in
323 /// particular, the second execution of reading_op happens after the first
324 /// execution of writing_op. This is problematic because the tensor %0 they
325 /// operate on (i.e., the "definition") is defined outside of the loop.
326 ///
327 /// On a high-level, there is a potential RaW in a program if there exists a
328 /// possible program execution such that there is a sequence of DEF, followed
329 /// by WRITE, followed by READ. Each additional DEF resets the sequence.
330 ///
331 /// E.g.:
332 /// No conflict: DEF, WRITE, DEF, READ
333 /// Potential conflict: DEF, READ, WRITE, READ, WRITE
334 ///
335 /// Example 1 has no conflict: DEF, READ, WRITE
336 /// Example 2 has a potential conflict: DEF, (READ, WRITE)*
337 //
338 /// Example 3:
339 /// scf.for ... {
340 /// %0 = ... : tensor<?xf32>
341 /// "reading_op"(%0) : tensor<?xf32>
342 /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
343 /// ...
344 /// }
345 /// This has no conflict: (DEF, READ, WRITE)*
346 ///
347 /// Example 4:
348 /// %0 = ... : tensor<?xf32>
349 /// scf.for ... {
350 /// scf.for ... { "reading_op"(%0) }
351 /// %1 = "writing_op"(%0)
352 /// }
353 /// This has a potential conflict: DEF, ((READ)*, WRITE)*
354 ///
355 /// Example 5:
356 /// %0 = ... : tensor<?xf32>
357 /// scf.for ... { %1 = "writing_op"(%0) }
358 /// scf.for ... { "reading_op"(%0) }
359 /// This has a potential conflict: DEF, WRITE*, READ*
360 ///
361 /// The following rules are used to rule out RaW conflicts via ordering of ops:
362 ///
363 /// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of
364 /// a repetitive region that enclosing both READ and WRITE, we cannot rule
365 /// out RaW conflict due to the ordering of ops.
366 /// 2. Otherwise: There are no loops that interfere with our analysis; for
367 /// analysis purposes, we can assume that there are no loops/repetitive
368 /// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
369 /// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
370 ///
// NOTE(review): signature start (orig line 371, presumably
// `static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand
// *uWrite,`) was dropped during extraction. See the large comment above for
// the rule being implemented (region-based repetitive-region check).
 372  const SetVector<Value> &definitions,
 373  AnalysisState &state) {
 374  const BufferizationOptions &options = state.getOptions();
 375  for (Value def : definitions) {
 376  Region *rRead =
 377  state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
 378  Region *rDef = state.getEnclosingRepetitiveRegion(def, options);
 379 
 380  // READ and DEF are in the same repetitive region. `happensBefore` can be
 381  // used to rule out RaW conflicts due to op ordering.
 382  if (rRead == rDef)
 383  continue;
 384 
 385  // Find the enclosing repetitive region of READ that is closest to DEF but
 386  // not the repetitive region of DEF itself.
 387  while (true) {
 388  Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
 389  if (nextRegion == rDef)
 390  break;
 391  assert(nextRegion && "expected to find another repetitive region");
 392  rRead = nextRegion;
 393  }
 394 
 395  // We cannot use op dominance if WRITE is inside the same repetitive region.
 396  if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
 397  return false;
 398  }
 399 
 400  return true;
 401 }
402 
403 /// Return `true` if op dominance can be used to rule out a read-after-write
404 /// conflicts based on the ordering of ops. Returns `false` if op dominance
405 /// cannot be used to due block-based loops within a region.
406 ///
407 /// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
408 /// how op domiance is used during RaW conflict detection.
409 ///
410 /// On a high-level, there is a potential RaW in a program if there exists a
411 /// possible program execution such that there is a sequence of DEF, followed
412 /// by WRITE, followed by READ. Each additional DEF resets the sequence.
413 ///
414 /// Op dominance cannot be used if there is a path from block(READ) to
415 /// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
416 /// not appear on that path.
// NOTE(review): signature start (orig line 417, presumably
// `static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand
// *uWrite,`) was dropped during extraction. Implements the block-based loop
// check described in the comment above via mutual reachability.
 418  const SetVector<Value> &definitions,
 419  AnalysisState &state) {
 420  // Fast path: If READ and WRITE are in different regions, their block cannot
 421  // be reachable just via unstructured control flow. (Loops due to regions are
 422  // covered by `canUseOpDominanceDueToRegions`.)
 423  if (uRead->getOwner()->getParentRegion() !=
 424  uWrite->getOwner()->getParentRegion())
 425  return true;
 426 
 427  Block *readBlock = uRead->getOwner()->getBlock();
 428  Block *writeBlock = uWrite->getOwner()->getBlock();
// A cycle through both blocks that avoids the definition's block means the
// READ/WRITE pair may repeat without an intervening DEF — dominance unusable.
 429  for (Value def : definitions) {
 430  Block *defBlock = def.getParentBlock();
 431  if (isReachable(readBlock, writeBlock, {defBlock}) &&
 432  isReachable(writeBlock, readBlock, {defBlock}))
 433  return false;
 434  }
 435 
 436  return true;
 437 }
438 
439 static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
440  const SetVector<Value> &definitions,
441  AnalysisState &state) {
442  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
443  canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
444 }
445 
446 /// Annotate IR with details about the detected RaW conflict.
447 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
448  Value definition) {
449  static uint64_t counter = 0;
450  Operation *readingOp = uRead->getOwner();
451  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
452 
453  OpBuilder b(conflictingWritingOp->getContext());
454  std::string id = "C_" + std::to_string(counter++);
455 
456  std::string conflictingWriteAttr =
457  id +
458  "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
459  "]";
460  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
461 
462  std::string readAttr =
463  id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
464  readingOp->setAttr(readAttr, b.getUnitAttr());
465 
466  if (auto opResult = dyn_cast<OpResult>(definition)) {
467  std::string defAttr =
468  id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
469  opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
470  } else {
471  auto bbArg = cast<BlockArgument>(definition);
472  std::string defAttr =
473  id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
474  bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
475  }
476 }
477 
478 /// Return 'true' if a tensor that is equivalent to `other` can be found in the
479 /// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
480 /// place along that use-def chain, the two tensors may not materialize as
481 /// equivalent buffers (but separate allocations).
482 ///
483 /// Note: This function also requires that the two tensors have equivalent
484 /// indexing. I.e., the tensor types do not change along the use-def chain,
485 /// apart from static <-> dynamic dim casts.
// NOTE(review): signature start (orig line 486, presumably
// `static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,`)
// was dropped during extraction.
 487  Value start, Value other) {
// The traversal is restricted to equivalent, same-type (or cast) links only;
// leaves are excluded so an empty result means "not found".
 488  TraversalConfig config;
 489  config.followEquivalentOnly = true;
 490  config.alwaysIncludeLeaves = false;
 491  config.followSameTypeOrCastsOnly = true;
 492  return !state
 493  .findValueInReverseUseDefChain(
 494  start, [&](Value v) { return v == other; }, config)
 495  .empty();
 496 }
497 
498 /// Return "true" if `value` is originating from a subset that is equivalent to
499 /// the subset that `subsetOp` inserts into.
500 static bool matchesInsertDestination(const AnalysisState &state, Value value,
501  SubsetInsertionOpInterface subsetOp) {
502  auto matchingSubset = [&](Value val) {
503  if (auto opResult = dyn_cast<OpResult>(val))
504  if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
505  return state.areEquivalentBufferizedValues(v1, v2);
506  }))
507  return true;
508  return false;
509  };
510  // There may be multiple leaves at which the reverse SSA use-def chain lookup
511  // terminates. All of them must be equivalent subsets.
512  SetVector<Value> backwardSlice =
513  state.findValueInReverseUseDefChain(value, matchingSubset);
514  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
515 }
516 
517 /// Return "true" if the given "read" and potentially conflicting "write" are
518 /// not conflicting due to their subset relationship. The comments in this
519 /// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
520 /// pairs, but apply to any subset ops that implement the
521 /// `SubsetInsertionOpInterface`.
// NOTE(review): signature start (orig line 522, presumably
// `static bool areNonConflictingSubsets(OpOperand *uRead,`) was dropped
// during extraction.
 523  OpOperand *uConflictingWrite,
 524  const AnalysisState &state) {
 525  Operation *readingOp = uRead->getOwner();
 526  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
 527 
 528  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
 529  // uRead is an InsertSliceOp...
 530  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
 531  // As an example, consider the following IR.
 532  //
 533  // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
 534  // %1 = linalg.fill %cst, %0 {inplace= [true] }
 535  // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
 536  // {inplace= [true] }
 537 
 538  if (uRead == &subsetOp.getDestinationOperand() &&
 539  matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
 540  // Case 1: The main insight is that InsertSliceOp reads only part of
 541  // the destination tensor. The overwritten area is not read. If
 542  // uConflictingWrite writes into exactly the memory location that is
 543  // being read by uRead, this is not a conflict.
 544  //
 545  // In the above example:
 546  // uRead = OpOperand 1 (%t) of tensor.insert_slice
 547  // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
 548  //
 549  // The read of %t does not conflict with the write of the FillOp
 550  // (same aliases!) because the area that the FillOp operates on is
 551  // exactly the one that is *not* read via %t.
 552  return true;
 553 
 554  if (uRead == &subsetOp.getSourceOperand() &&
 555  uConflictingWrite == &subsetOp.getDestinationOperand() &&
 556  matchesInsertDestination(state, uRead->get(), subsetOp))
 557  // Case 2: The read of the source tensor and the write to the dest
 558  // tensor via an InsertSliceOp is not a conflict if the read is
 559  // reading exactly that part of an equivalent tensor that the
 560  // InsertSliceOp is writing.
 561  //
 562  // In the above example:
 563  // uRead = OpOperand 0 (%1) of tensor.insert_slice
 564  // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
 565  return true;
 566  }
 567 
 568  // If uConflictingWrite is an InsertSliceOp...
 569  if (auto subsetOp =
 570  dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
 571  // As an example, consider the following IR.
 572  //
 573  // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
 574  // %1 = linalg.fill %cst, %0 {inplace= [true] }
 575  // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
 576  // {inplace= [true] }
 577  // %3 = vector.transfer_read %1, %cst
 578  //
 579  // In the above example:
 580  // uRead = OpOperand 0 (%1) of vector.transfer_read
 581  // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
 582  // definition = %1
 583  //
 584  // This is not a conflict because the InsertSliceOp overwrites the
 585  // memory segment of %1 with the exact same data. (Effectively, there
 586  // is no memory write here.)
 587  if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
 588  state.areEquivalentBufferizedValues(
 589  uRead->get(), subsetOp.getSourceOperand().get()) &&
 590  matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
 591  subsetOp))
 592  return true;
 593 
 594  return false;
 595 }
596 
597 /// Given sets of uses and writes, return true if there is a RaW conflict under
598 /// the assumption that all given reads/writes alias the same buffer and that
599 /// all given writes bufferize inplace.
600 ///
601 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
602 /// the result of a definition W1. But because of bufferization decisions, R
603 /// actually reads another definition W2.
// NOTE(review): several lines of this function were dropped during extraction:
// orig line 605 (the function name line, presumably
// `hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,`)
// and orig lines 737/739 inside the element-wise check below. Also note the
// debug strings contain upstream typos ("unConflictingWrite", "op interace");
// fixing them changes runtime debug output, so they are left untouched here.
604 static bool
 606  const DenseSet<OpOperand *> &usesWrite,
 607  const DominanceInfo &domInfo,
 608  OneShotAnalysisState &state) {
 609  const BufferizationOptions &options = state.getOptions();
 610 
 611  // Before going through the main RaW analysis, find cases where a buffer must
 612  // be privatized due to parallelism. If the result of a write is never read,
 613  // privatization is not necessary (and large parts of the IR are likely dead).
 614  if (!usesRead.empty()) {
 615  for (OpOperand *uConflictingWrite : usesWrite) {
 616  // Find the allocation point or last write (definition) of the buffer.
 617  // Note: In contrast to `findDefinitions`, this also returns results of
 618  // ops that do not bufferize to memory write when no other definition
 619  // could be found. E.g., "bufferization.alloc_tensor" would be included,
 620  // even though that op just bufferizes to an allocation but does define
 621  // the contents of the buffer.
 622  SetVector<Value> definitionsOrLeaves =
 623  state.findValueInReverseUseDefChain(
 624  uConflictingWrite->get(),
 625  [&](Value v) { return state.bufferizesToMemoryWrite(v); });
 626  assert(!definitionsOrLeaves.empty() &&
 627  "expected at least one definition or leaf");
 628 
 629  // The writing op must bufferize out-of-place if the definition is in a
 630  // different parallel region than this write.
 631  for (Value def : definitionsOrLeaves) {
 632  if (getParallelRegion(def.getParentRegion(), options) !=
 633  getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
 634  options)) {
 635  LLVM_DEBUG(
 636  llvm::dbgs()
 637  << "\n- bufferizes out-of-place due to parallel region:\n");
 638  LLVM_DEBUG(llvm::dbgs()
 639  << " unConflictingWrite = operand "
 640  << uConflictingWrite->getOperandNumber() << " of "
 641  << *uConflictingWrite->getOwner() << "\n");
 642  return true;
 643  }
 644  }
 645  }
 646  }
 647 
 648  for (OpOperand *uRead : usesRead) {
 649  Operation *readingOp = uRead->getOwner();
 650  LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
 651  LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
 652  << " of " << *readingOp << "\n");
 653 
 654  // Find the definition of uRead by following the SSA use-def chain.
 655  // E.g.:
 656  //
 657  // %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
 658  // %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
 659  // %2 = "reading_op"(%1) : : tensor<?x32> -> not_a_tensor_type
 660  //
 661  // In the above example, if uRead is the OpOperand of reading_op, the
 662  // definition is %0. Note that operations that create an alias but do not
 663  // bufferize to a memory write (such as ExtractSliceOp) are skipped.
 664  const SetVector<Value> &definitions =
 665  state.findDefinitionsCached(uRead->get());
 666  if (definitions.empty()) {
 667  // Fast path: No conflict if there are no definitions.
 668  LLVM_DEBUG(llvm::dbgs()
 669  << " no conflict: read value has no definitions\n");
 670  continue;
 671  }
 672 
 673  // Look for conflicting memory writes. Potential conflicts are writes to an
 674  // alias that have been decided to bufferize inplace.
 675  for (OpOperand *uConflictingWrite : usesWrite) {
 676  LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
 677  << uConflictingWrite->getOperandNumber() << " of "
 678  << *uConflictingWrite->getOwner() << "\n");
 679 
 680  // Check if op dominance can be used to rule out read-after-write
 681  // conflicts.
 682  bool useDominance =
 683  canUseOpDominance(uRead, uConflictingWrite, definitions, state);
 684  LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
 685 
 686  // Throughout this loop, check for multiple requirements that have to be
 687  // met for uConflictingWrite to be an actual conflict.
 688  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
 689 
 690  // Inside of repetitive regions, ops may be executed multiple times and op
 691  // dominance cannot be used to rule out conflicts.
 692  if (useDominance) {
 693  // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
 694  // the write is not visible when reading.
 695  //
 696  // Note: If ops are executed multiple times (e.g., because they are
 697  // inside a loop), there may be no meaningful `happensBefore`
 698  // relationship.
 699  if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
 700  LLVM_DEBUG(llvm::dbgs()
 701  << " no conflict: read happens before write\n");
 702  continue;
 703  }
 704 
 705  // No conflict if the reading use equals the use of the conflicting
 706  // write. A use cannot conflict with itself.
 707  //
 708  // Note: Just being the same op is not enough. It has to be the same
 709  // use.
 710  // Note: If the op is executed multiple times (e.g., because it is
 711  // inside a loop), it may be conflicting with itself.
 712  if (uConflictingWrite == uRead) {
 713  LLVM_DEBUG(llvm::dbgs()
 714  << " no conflict: read and write are same use\n");
 715  continue;
 716  }
 717 
 718  // Ops are not conflicting if they are in mutually exclusive regions.
 719  //
 720  // Note: If ops are executed multiple times (e.g., because they are
 721  // inside a loop), mutually exclusive regions may be executed
 722  // multiple times.
 723  if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
 724  LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
 725  "mutually exclusive regions\n");
 726  continue;
 727  }
 728  }
 729 
 730  // Two equivalent operands of the same op are not conflicting if the op
 731  // bufferizes to element-wise access. I.e., all loads at a position happen
 732  // before all stores to the same position.
 733  if (conflictingWritingOp == readingOp) {
 734  if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
 735  if (bufferizableOp.bufferizesToElementwiseAccess(
 736  state, {uRead, uConflictingWrite})) {
// NOTE(review): orig lines 737 and 739 were dropped during extraction — the
// two fragments below are orphaned argument lists of an equivalence check
// (presumably calls to `hasEquivalentValueInReverseUseDefChain(`); confirm
// against upstream.
 738  state, uRead->get(), uConflictingWrite->get()) ||
 740  state, uConflictingWrite->get(), uRead->get())) {
 741  LLVM_DEBUG(
 742  llvm::dbgs()
 743  << " no conflict: op bufferizes to element-wise access\n");
 744  continue;
 745  }
 746  }
 747  }
 748  }
 749 
 750  // No conflict if the operands are non-conflicting subsets.
 751  if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
 752  LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
 753  continue;
 754  }
 755 
 756  // No conflict if the op interface says so.
 757  if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
 758  if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
 759  LLVM_DEBUG(llvm::dbgs()
 760  << " no conflict: op interace of reading op says 'no'\n");
 761  continue;
 762  }
 763  }
 764 
 765  if (conflictingWritingOp != readingOp) {
 766  if (auto bufferizableOp =
 767  options.dynCastBufferizableOp(conflictingWritingOp)) {
 768  if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
 769  state)) {
 770  LLVM_DEBUG(
 771  llvm::dbgs()
 772  << " no conflict: op interace of writing op says 'no'\n");
 773  continue;
 774  }
 775  }
 776  }
 777 
 778  // Check all possible definitions.
 779  for (Value definition : definitions) {
 780  LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
 781 
 782  // No conflict if the conflicting write happens before the definition.
 783  if (Operation *defOp = definition.getDefiningOp()) {
 784  if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
 785  // conflictingWritingOp happens before defOp. No conflict.
 786  LLVM_DEBUG(llvm::dbgs()
 787  << " no conflict: write happens before definition\n");
 788  continue;
 789  }
 790  // No conflict if conflictingWritingOp is contained in defOp.
 791  if (defOp->isProperAncestor(conflictingWritingOp)) {
 792  LLVM_DEBUG(
 793  llvm::dbgs()
 794  << " no conflict: write is contained in definition\n");
 795  continue;
 796  }
 797  } else {
 798  auto bbArg = cast<BlockArgument>(definition);
 799  Block *block = bbArg.getOwner();
 800  if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
 801  LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
 802  "and write happens outside of block\n");
 803  // conflictingWritingOp happens outside of the block. No
 804  // conflict.
 805  continue;
 806  }
 807  }
 808 
 809  // No conflict if the conflicting write and the definition are the same
 810  // use.
 811  AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
 812  if (aliases.getNumAliases() == 1 &&
 813  aliases.getAliases()[0].value == definition) {
 814  LLVM_DEBUG(llvm::dbgs()
 815  << " no conflict: definition and write are same\n");
 816  continue;
 817  }
 818 
 819  // All requirements are met. Conflict found!
 820 
 821  if (options.printConflicts)
 822  annotateConflict(uRead, uConflictingWrite, definition);
 823  LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
 824  return true;
 825  }
 826  }
 827  }
 828 
 829  return false;
 830 }
831 
832 // Helper function to iterate on aliases of `root` and capture the writes.
834  const OneShotAnalysisState &state) {
835  state.applyOnAliases(root, [&](Value alias) {
836  for (auto &use : alias.getUses())
837  // Inplace write to a value that aliases root.
838  if (isInplaceMemoryWrite(use, state))
839  res.insert(&use);
840  });
841 }
842 
843 // Helper function to iterate on aliases of `root` and capture the reads.
845  const OneShotAnalysisState &state) {
846  state.applyOnAliases(root, [&](Value alias) {
847  for (auto &use : alias.getUses()) {
848  // Read of a value that aliases root.
849  if (state.bufferizesToMemoryRead(use)) {
850  res.insert(&use);
851  continue;
852  }
853 
854  // Read of a dependent value in the SSA use-def chain. E.g.:
855  //
856  // %0 = ...
857  // %1 = tensor.extract_slice %0 {not_analyzed_yet}
858  // "read"(%1)
859  //
860  // In the above example, getAliasingReads(%0) includes the first OpOperand
861  // of the tensor.extract_slice op. The extract_slice itself does not read
862  // but its aliasing result is eventually fed into an op that does.
863  //
864  // Note: This is considered a "read" only if the use does not bufferize to
865  // a memory write. (We already ruled out memory reads. In case of a memory
866  // write, the buffer would be entirely overwritten; in the above example
867  // there would then be no flow of data from the extract_slice operand to
868  // its result's uses.)
869  if (!state.bufferizesToMemoryWrite(use)) {
870  AliasingValueList aliases = state.getAliasingValues(use);
871  if (llvm::any_of(aliases, [&](AliasingValue a) {
872  return state.isValueRead(a.value);
873  }))
874  res.insert(&use);
875  }
876  }
877  });
878 }
879 
880 /// Return true if bufferizing `operand` inplace would create a conflict. A read
881 /// R and a write W of the same alias set is a conflict if inplace bufferization
882 /// of W changes the value read by R to a value different from the one that
883 /// would be expected by tracing back R's origin through SSA use-def chains.
884 /// A conflict can only be introduced by a new alias and/or an inplace
885 /// bufferization decision.
886 ///
887 /// Example:
888 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
889 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
890 /// %e = tensor.extract_slice %1
891 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
892 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
893 ///
894 /// In the above example, the two TransferWriteOps have already been decided to
895 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
896 /// conflict because:
897 /// * According to SSA use-def chains, we expect to read the result of %1.
898 /// * However, adding an alias {%0, %t} would mean that the second
899 /// TransferWriteOp overwrites the result of the first one. Therefore, the
900 /// TransferReadOp would no longer be reading the result of %1.
901 ///
902 /// If `checkConsistencyOnly` is true, this function checks if there is a
903 /// read-after-write conflict without bufferizing `operand` inplace. This would
904 /// indicate a problem with the current inplace bufferization decisions.
905 ///
906 /// Note: If `checkConsistencyOnly`, this function may be called with a null
907 /// OpResult. In that case, only the consistency of bufferization decisions
908 /// involving aliases of the given OpOperand are checked.
910  OpOperand &operand, const DominanceInfo &domInfo,
911  OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
912  // Collect reads and writes of all aliases of OpOperand and OpResult.
913  DenseSet<OpOperand *> usesRead, usesWrite;
914  getAliasingReads(usesRead, operand.get(), state);
915  getAliasingInplaceWrites(usesWrite, operand.get(), state);
916  for (AliasingValue alias : state.getAliasingValues(operand)) {
917  getAliasingReads(usesRead, alias.value, state);
918  getAliasingInplaceWrites(usesWrite, alias.value, state);
919  }
920  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
921  usesWrite.insert(&operand);
922 
923  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
924 }
925 
926 /// Annotate IR with details about the detected non-writability conflict.
927 static void annotateNonWritableTensor(Value value) {
928  static int64_t counter = 0;
929  OpBuilder b(value.getContext());
930  std::string id = "W_" + std::to_string(counter++);
931  if (auto opResult = dyn_cast<OpResult>(value)) {
932  std::string attr = id + "[NOT-WRITABLE: result " +
933  std::to_string(opResult.getResultNumber()) + "]";
934  opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
935  } else {
936  auto bbArg = cast<BlockArgument>(value);
937  std::string attr = id + "[NOT-WRITABLE: bbArg " +
938  std::to_string(bbArg.getArgNumber()) + "]";
939  bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
940  }
941 }
942 
943 /// Return true if bufferizing `operand` inplace would create a write to a
944 /// non-writable buffer.
945 static bool
947  OneShotAnalysisState &state,
948  bool checkConsistencyOnly = false) {
949  bool foundWrite =
950  !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
951 
952  if (!foundWrite) {
953  // Collect writes of all aliases of OpOperand and OpResult.
954  DenseSet<OpOperand *> usesWrite;
955  getAliasingInplaceWrites(usesWrite, operand.get(), state);
956  for (AliasingValue alias : state.getAliasingValues(operand))
957  getAliasingInplaceWrites(usesWrite, alias.value, state);
958  foundWrite = !usesWrite.empty();
959  }
960 
961  if (!foundWrite)
962  return false;
963 
964  // Look for a read-only tensor among all aliases.
965  bool foundReadOnly = false;
966  auto checkReadOnly = [&](Value v) {
967  if (!state.isWritable(v)) {
968  foundReadOnly = true;
969  if (state.getOptions().printConflicts)
971  }
972  };
973  state.applyOnAliases(operand.get(), checkReadOnly);
974  for (AliasingValue alias : state.getAliasingValues(operand))
975  state.applyOnAliases(alias.value, checkReadOnly);
976  if (foundReadOnly) {
977  LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
978  return true;
979  }
980 
981  return false;
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // Bufferization analyses.
986 //===----------------------------------------------------------------------===//
987 
988 // Find the values that define the contents of the given value.
990 OneShotAnalysisState::findDefinitionsCached(Value value) {
991  if (!cachedDefinitions.count(value))
992  cachedDefinitions[value] = findDefinitions(value);
993  return cachedDefinitions[value];
994 }
995 
998  cachedDefinitions.clear();
999 }
1000 
1001 /// Determine if `operand` can be bufferized in-place.
1002 static LogicalResult
1004  const DominanceInfo &domInfo) {
1005  LLVM_DEBUG(
1006  llvm::dbgs() << "//===-------------------------------------------===//\n"
1007  << "Analyzing operand #" << operand.getOperandNumber()
1008  << " of " << *operand.getOwner() << "\n");
1009 
1010  bool foundInterference =
1011  wouldCreateWriteToNonWritableBuffer(operand, state) ||
1012  wouldCreateReadAfterWriteInterference(operand, domInfo, state);
1013 
1014  if (foundInterference)
1015  state.bufferizeOutOfPlace(operand);
1016  else
1017  state.bufferizeInPlace(operand);
1018 
1019  LLVM_DEBUG(llvm::dbgs()
1020  << "//===-------------------------------------------===//\n");
1021  return success();
1022 }
1023 
1026  const DominanceInfo &domInfo) {
1027  for (OpOperand &opOperand : op->getOpOperands())
1028  if (isa<TensorType>(opOperand.get().getType()))
1029  if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
1030  return failure();
1031  return success();
1032 }
1033 
1034 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
1036  OneShotAnalysisState &state) {
1037  for (Operation *op : ops) {
1038  if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
1039  for (OpResult opResult : op->getOpResults()) {
1040  if (!isa<TensorType>(opResult.getType()))
1041  continue;
1042  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
1043  if (aliases.getNumAliases() == 0)
1044  // Nothing to do if there are no aliasing OpOperands.
1045  continue;
1046 
1047  Value firstOperand = aliases.begin()->opOperand->get();
1048  bool allEquivalent = true;
1049  for (AliasingOpOperand alias : aliases) {
1050  bool isEquiv = alias.relation == BufferRelation::Equivalent;
1051  bool isInPlace = state.isInPlace(*alias.opOperand);
1052  Value operand = alias.opOperand->get();
1053  if (isEquiv && isInPlace && alias.isDefinite) {
1054  // Found a definite, equivalent alias. Merge equivalence sets.
1055  // There can only be one definite alias, so we can stop here.
1056  state.unionEquivalenceClasses(opResult, operand);
1057  allEquivalent = false;
1058  break;
1059  }
1060  if (!isEquiv || !isInPlace)
1061  allEquivalent = false;
1062  if (!state.areEquivalentBufferizedValues(operand, firstOperand))
1063  allEquivalent = false;
1064  }
1065 
1066  // If all "maybe" aliases are equivalent and the OpResult is not a new
1067  // allocation, it is a definite, equivalent alias. E.g.:
1068  //
1069  // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
1070  // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
1071  // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
1072  // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
1073  //
1074  // If %t0 and %t1 are equivalent, it is safe to union the equivalence
1075  // classes of %r, %t0 and %t1.
1076  if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
1077  state.unionEquivalenceClasses(opResult, firstOperand);
1078  }
1079  }
1080  }
1081 }
1082 
1083 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
1084 /// in `op`.
1086  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
1088  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
1089  // No tensors => no buffers.
1090  if (none_of(op->getResultTypes(), isaTensor))
1091  return;
1092  ops.push_back(op);
1093  });
1094 
1095  equivalenceAnalysis(ops, state);
1096 }
1097 
1098 /// "Bottom-up from terminators" heuristic.
1101  const OneShotAnalysisState &state) {
1102  SetVector<Operation *> traversedOps;
1103 
1104  // Find region terminators.
1105  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
1106  if (!traversedOps.insert(term))
1107  return;
1108  // Follow the reverse SSA use-def chain from each yielded value as long as
1109  // we stay within the same region.
1110  SmallVector<OpResult> worklist;
1111  for (Value v : term->getOperands()) {
1112  if (!isa<TensorType>(v.getType()))
1113  continue;
1114  auto opResult = dyn_cast<OpResult>(v);
1115  if (!opResult)
1116  continue;
1117  worklist.push_back(opResult);
1118  }
1119  while (!worklist.empty()) {
1120  OpResult opResult = worklist.pop_back_val();
1121  Operation *defOp = opResult.getDefiningOp();
1122  if (!traversedOps.insert(defOp))
1123  continue;
1124  if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
1125  continue;
1126  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
1127  for (auto alias : aliases) {
1128  Value v = alias.opOperand->get();
1129  if (!isa<TensorType>(v.getType()))
1130  continue;
1131  auto opResult = dyn_cast<OpResult>(v);
1132  if (!opResult)
1133  continue;
1134  worklist.push_back(opResult);
1135  }
1136  }
1137  });
1138 
1139  // Analyze traversed ops, then all remaining ops.
1140  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
1142  if (!traversedOps.contains(op) && hasTensorSemantics(op))
1143  result.push_back(op);
1144  });
1145  return result;
1146 }
1147 
1149  const DominanceInfo &domInfo) {
1152 
1153  SmallVector<Operation *> orderedOps;
1154  if (heuristic ==
1156  orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
1157  } else {
1158  op->walk([&](Operation *op) {
1159  // No tensors => no buffers.
1160  if (!hasTensorSemantics(op))
1161  return;
1162  orderedOps.push_back(op);
1163  });
1164  switch (heuristic) {
1166  // Default: Walk ops in reverse for better interference analysis.
1167  std::reverse(orderedOps.begin(), orderedOps.end());
1168  break;
1169  }
1171  // Ops are already sorted top-down in `orderedOps`.
1172  break;
1173  }
1175  assert(getOptions().analysisFuzzerSeed &&
1176  "expected that fuzzer seed it set");
1177  // This is a fuzzer. For testing purposes only. Randomize the order in
1178  // which operations are analyzed. The bufferization quality is likely
1179  // worse, but we want to make sure that no assertions are triggered
1180  // anywhere.
1181  std::mt19937 g(getOptions().analysisFuzzerSeed);
1182  llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
1183  break;
1184  }
1185  default: {
1186  llvm_unreachable("unsupported heuristic");
1187  }
1188  }
1189  }
1190 
1191  // Analyze ops in the computed order.
1192  for (Operation *op : orderedOps)
1193  if (failed(analyzeSingleOp(op, domInfo)))
1194  return failure();
1195 
1196  equivalenceAnalysis(op, *this);
1197  return success();
1198 }
1199 
1200 /// Perform various checks on the input IR to see if it contains IR constructs
1201 /// that are unsupported by One-Shot Bufferize.
1202 static LogicalResult
1204  OneShotAnalysisState &state) {
1205  const BufferizationOptions &options = state.getOptions();
1206 
1207  // Note: This walk cannot be combined with the one below because interface
1208  // methods of invalid/unsupported ops may be called during the second walk.
1209  // (On ops different from `op`.)
1210  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
1211  // Skip ops that are not in the filter.
1212  if (!options.isOpAllowed(op.getOperation()))
1213  return WalkResult::advance();
1214 
1215  // Check for unsupported unstructured control flow.
1216  if (!op.supportsUnstructuredControlFlow()) {
1217  for (Region &r : op->getRegions()) {
1218  if (r.getBlocks().size() > 1) {
1219  op->emitOpError("op or BufferizableOpInterface implementation does "
1220  "not support unstructured control flow, but at least "
1221  "one region has multiple blocks");
1222  return WalkResult::interrupt();
1223  }
1224  }
1225  }
1226 
1227  return WalkResult::advance();
1228  });
1229  if (walkResult.wasInterrupted())
1230  return failure();
1231 
1232  walkResult = op->walk([&](BufferizableOpInterface op) {
1233  // Skip ops that are not in the filter.
1234  if (!options.isOpAllowed(op.getOperation()))
1235  return WalkResult::advance();
1236 
1237  // Input IR may not contain any ToTensorOps without the "restrict"
1238  // attribute. Such tensors may alias any other tensor, which is currently
1239  // not handled in the analysis.
1240  if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
1241  if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
1242  op->emitOpError("to_tensor ops without `restrict` are not supported by "
1243  "One-Shot Analysis");
1244  return WalkResult::interrupt();
1245  }
1246  }
1247 
1248  for (OpOperand &opOperand : op->getOpOperands()) {
1249  if (isa<TensorType>(opOperand.get().getType())) {
1251  opOperand, domInfo, state,
1252  /*checkConsistencyOnly=*/true)) {
1253  // This error can happen if certain "mustBufferizeInPlace" interface
1254  // methods are implemented incorrectly, such that the IR already has
1255  // a RaW conflict before making any bufferization decisions. It can
1256  // also happen if the bufferization.materialize_in_destination is used
1257  // in such a way that a RaW conflict is not avoidable.
1258  op->emitOpError("not bufferizable under the given constraints: "
1259  "cannot avoid RaW conflict");
1260  return WalkResult::interrupt();
1261  }
1262 
1263  if (state.isInPlace(opOperand) &&
1265  opOperand, state, /*checkConsistencyOnly=*/true)) {
1266  op->emitOpError("not bufferizable under the given constraints: would "
1267  "write to read-only buffer");
1268  return WalkResult::interrupt();
1269  }
1270  }
1271  }
1272 
1273  return WalkResult::advance();
1274  });
1275 
1276  return success(!walkResult.wasInterrupted());
1277 }
1278 
1279 /// Annotate the IR with the result of the analysis. For testing/debugging only.
1280 static void
1282  const OneShotAnalysisState &state) {
1283  // Add __inplace_operands_attr__.
1284  op->walk([&](Operation *op) {
1285  for (OpOperand &opOperand : op->getOpOperands())
1286  if (isa<TensorType>(opOperand.get().getType()))
1287  setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
1288  });
1289 }
1290 
1292  const OneShotAnalysisState &state) {
1293  AsmState asmState(op);
1294  Builder b(op->getContext());
1295  // Helper function to build an array attribute of aliasing SSA value strings.
1296  auto buildAliasesArray = [&](Value v) {
1297  SmallVector<Attribute> aliases;
1298  state.applyOnAliases(v, [&](Value alias) {
1299  std::string buffer;
1300  llvm::raw_string_ostream stream(buffer);
1301  alias.printAsOperand(stream, asmState);
1302  aliases.push_back(b.getStringAttr(stream.str()));
1303  });
1304  return b.getArrayAttr(aliases);
1305  };
1306 
1307  op->walk([&](Operation *op) {
1308  // Build alias set array for every OpResult.
1309  SmallVector<Attribute> opResultAliasSets;
1310  for (OpResult opResult : op->getOpResults()) {
1311  if (llvm::isa<TensorType>(opResult.getType())) {
1312  opResultAliasSets.push_back(buildAliasesArray(opResult));
1313  }
1314  }
1315  if (!opResultAliasSets.empty())
1316  op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));
1317 
1318  // Build alias set array for every BlockArgument.
1319  SmallVector<Attribute> regionAliasSets;
1320  bool hasTensorBbArg = false;
1321  for (Region &r : op->getRegions()) {
1322  SmallVector<Attribute> blockAliasSets;
1323  for (Block &block : r.getBlocks()) {
1324  SmallVector<Attribute> bbArgAliasSets;
1325  for (BlockArgument bbArg : block.getArguments()) {
1326  if (llvm::isa<TensorType>(bbArg.getType())) {
1327  bbArgAliasSets.push_back(buildAliasesArray(bbArg));
1328  hasTensorBbArg = true;
1329  }
1330  }
1331  blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
1332  }
1333  regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
1334  }
1335  if (hasTensorBbArg)
1336  op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
1337  });
1338 }
1339 
1341  OneShotAnalysisState &state,
1342  BufferizationStatistics *statistics) {
1343  DominanceInfo domInfo(op);
1344  const OneShotBufferizationOptions &options = state.getOptions();
1345 
1346  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
1347  return failure();
1348 
1349  // If the analysis fails, just return.
1350  if (failed(state.analyzeOp(op, domInfo)))
1351  return failure();
1352 
1353  if (statistics) {
1354  statistics->numTensorInPlace = state.getStatNumTensorInPlace();
1355  statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
1356  }
1357 
1358  bool failedAnalysis = false;
1359 
1360  // Gather some extra analysis data.
1361  state.gatherUndefinedTensorUses(op);
1362 
1363  // Analysis verification: After setting up alias/equivalence sets, each op
1364  // can check for expected invariants/limitations and fail the analysis if
1365  // necessary.
1366  op->walk([&](Operation *op) {
1367  if (BufferizableOpInterface bufferizableOp =
1368  options.dynCastBufferizableOp(op))
1369  failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
1370  });
1371 
1372  // Annotate operations if we only want to report the analysis.
1373  if (options.testAnalysisOnly)
1375  if (options.dumpAliasSets)
1376  annotateOpsWithAliasSets(op, state);
1377 
1378  return success(!failedAnalysis);
1379 }
1380 
1384  BufferizationStatistics *statistics) {
1385  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
1386  "invalid combination of bufferization flags");
1387  if (!options.copyBeforeWrite) {
1388  // If a buffer is copied before every write, no analysis is needed.
1389  if (failed(insertTensorCopies(op, options, statistics)))
1390  return failure();
1391  }
1392  if (options.testAnalysisOnly)
1393  return success();
1394  return bufferizeOp(op, options, statistics);
1395 }
static bool hasReadAfterWriteInterference(const DenseSet< OpOperand * > &usesRead, const DenseSet< OpOperand * > &usesWrite, const DominanceInfo &domInfo, OneShotAnalysisState &state)
Given sets of uses and writes, return true if there is a RaW conflict under the assumption that all g...
static void getAliasingReads(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)
static void equivalenceAnalysis(SmallVector< Operation * > &ops, OneShotAnalysisState &state)
Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace)
Mark whether OpOperand will be bufferized inplace.
static SmallVector< Operation * > bottomUpFromTerminatorsHeuristic(Operation *op, const OneShotAnalysisState &state)
"Bottom-up from terminators" heuristic.
constexpr StringLiteral kInPlaceOperandsAttrName
Attribute marker to specify op operands that bufferize in-place.
static bool isaTensor(Type t)
static void annotateNonWritableTensor(Value value)
Annotate IR with details about the detected non-writability conflict.
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)
Return true if op dominance can be used to rule out a read-after-write conflicts based on the orderin...
static LogicalResult bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo)
Determine if operand can be bufferized in-place.
static bool isReachable(Block *from, Block *to, ArrayRef< Block * > except)
static bool matchesInsertDestination(const AnalysisState &state, Value value, SubsetInsertionOpInterface subsetOp)
Return "true" if value is originating from a subset that is equivalent to the subset that subsetOp in...
constexpr StringLiteral kOpResultAliasSetAttrName
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, Value start, Value other)
Return 'true' if a tensor that is equivalent to other can be found in the reverse use-def chain of st...
static bool happensBefore(Operation *a, Operation *b, const DominanceInfo &domInfo)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)
static bool wouldCreateWriteToNonWritableBuffer(OpOperand &operand, OneShotAnalysisState &state, bool checkConsistencyOnly=false)
Return true if bufferizing operand inplace would create a write to a non-writable buffer.
static void annotateOpsWithAliasSets(Operation *op, const OneShotAnalysisState &state)
static LogicalResult checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo, OneShotAnalysisState &state)
Perform various checks on the input IR to see if it contains IR constructs that are unsupported by On...
static void annotateOpsWithBufferizationMarkers(Operation *op, const OneShotAnalysisState &state)
Annotate the IR with the result of the analysis. For testing/debugging only.
static bool wouldCreateReadAfterWriteInterference(OpOperand &operand, const DominanceInfo &domInfo, OneShotAnalysisState &state, bool checkConsistencyOnly=false)
Return true if bufferizing operand inplace would create a conflict.
constexpr StringLiteral kBbArgAliasSetAttrName
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)
Return true if op dominance can be used to rule out a read-after-write conflicts based on the orderin...
static void getAliasingInplaceWrites(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)
static bool areNonConflictingSubsets(OpOperand *uRead, OpOperand *uConflictingWrite, const AnalysisState &state)
Return "true" if the given "read" and potentially conflicting "write" are not conflicting due to thei...
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value definition)
Annotate IR with details about the detected RaW conflict.
static bool isInplaceMemoryWrite(OpOperand &opOperand, const OneShotAnalysisState &state)
Return true if opOperand has been decided to bufferize in-place.
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:263
Base class for generic analysis states.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:533
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:73
SuccessorRange getSuccessors()
Definition: Block.h:264
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
A class for computing basic dominance information.
Definition: Dominance.h:136
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:149
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class helps build Operations.
Definition: Builders.h:209
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
result_range getOpResults()
Definition: Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:128
Type getType() const
Return the type of this value.
Definition: Value.h:125
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult skip()
Definition: Visitors.h:53
static WalkResult advance()
Definition: Visitors.h:52
static WalkResult interrupt()
Definition: Visitors.h:51
AnalysisState provides a variety of helper functions for dealing with tensor values.
AliasingValueList getAliasingValues(OpOperand &opOperand) const
Determine which Value will alias with opOperand if the op is bufferized in place.
bool bufferizesToMemoryWrite(OpOperand &opOperand) const
Return true if opOperand bufferizes to a memory write.
SetVector< Value > findDefinitions(Value value) const
Find the values that may define the contents of the given value at runtime.
virtual ~Extension()
Base virtual destructor.
State for analysis-enabled bufferization.
void bufferizeOutOfPlace(OpOperand &operand)
Mark the given OpOperand as out-of-place.
bool isWritable(Value value) const
Return true if the buffer of the given tensor value is writable.
const SetVector< Value > & findDefinitionsCached(Value value)
Find the definitions of the given tensor value or retrieve them from the cache.
bool isInPlace(OpOperand &opOperand) const override
Return true if the given OpResult has been decided to bufferize inplace.
LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo)
Analyze the given op and its nested ops.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
void unionEquivalenceClasses(Value v1, Value v2)
Union the equivalence classes of v1 and v2.
void gatherUndefinedTensorUses(Operation *op)
Find all tensor values in the given operation that have undefined contents and store them in undefine...
void resetCache() override
Reset cached data structures.
LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo)
Analyze a single op (without nested ops).
void applyOnEquivalenceClass(Value v, function_ref< void(Value)> fun) const
Apply fun to all the members of the equivalence class of v.
bool hasUndefinedContents(OpOperand *opOperand) const override
Return true if the given tensor has undefined contents.
void bufferizeInPlace(OpOperand &operand)
Mark the given OpOperand as in-place and merge the results' and operand's aliasing sets.
void applyOnAliases(Value v, function_ref< void(Value)> fun) const
Apply fun to all aliases of v.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
OneShotAnalysisState(Operation *op, const OneShotBufferizationOptions &options)
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
void unionAliasSets(Value v1, Value v2)
Union the alias sets of v1 and v2.
void createAliasInfoEntry(Value v)
Add a new entry for v in the aliasInfo and equivalentInfo.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
Definition: Bufferize.cpp:439
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
Region * getParallelRegion(Region *region, const BufferizationOptions &options)
If region is a parallel region, return region.
Region * getNextEnclosingRepetitiveRegion(Region *region, const BufferizationOptions &options)
Assuming that the given region is repetitive, find the next enclosing repetitive region.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This iterator enumerates elements in "reverse" order.
Definition: Iterators.h:29
Options for BufferizableOpInterface-based bufferization.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const
Try to cast the given op to BufferizableOpInterface if the op is allow listed.
bool isOpAllowed(Operation *op) const
Return true if the given op should be bufferized.
Bufferization statistics for debugging.
Definition: Bufferize.h:34
Options for analysis-enabled bufferization.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
Traversal parameters for findValueInReverseUseDefChain.
bool alwaysIncludeLeaves
Specifies if leaves (that do not have further OpOperands to follow) should be returned even if they do not match the traversal condition.
bool followSameTypeOrCastsOnly
Specifies whether OpOperands with a different type that are not the result of a CastOpInterface op should be followed.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.