MLIR  20.0.0git
OneShotAnalysis.cpp
Go to the documentation of this file.
1 //===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // One-Shot Analysis analyzes function bodies. By default, function boundaries
10 // (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
11 // OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
12 // simple call graphs without loops.
13 //
14 // One-Shot Bufferize consists of three phases.
15 //
16 // 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
17 // without inserting buffer copies. The analysis queries op bufferization
18 // semantics via `BufferizableOpInterface`.
19 // 2. Insert copies for OpOperands that were decided to bufferize out-of-place
20 // in tensor land during `TensorCopyInsertion`.
21 // 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
22 //
23 // This file contains only the analysis. For convenience, this file also
24 // contains a helper function `runOneShotBufferize` that analyzes an op (and its
25 // nested ops) and then bufferizes it.
26 //
27 // Inplace bufferization decisions are passed from the analysis to the
28 // `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
29 // debugging purposes with `testAnalysisOnly`.
30 //
31 // Ops that do not implement `BufferizableOpInterface` can be analyzed but are
32 // treated conservatively. E.g., the analysis has to assume that their tensor
33 // OpOperands bufferize to memory writes. While such ops can be analyzed, they
34 // are not bufferized and remain in the IR. to_tensor and to_memref ops are
35 // inserted at the bufferization boundary.
36 //
37 // This analysis caters to high-performance codegen where buffer reuse is deemed
38 // critical: the analysis should fail if the bufferized form of the function
39 // needs to return a buffer, unless `allowReturnAllocs` is enabled.
40 
42 
43 #include <optional>
44 #include <random>
45 
52 #include "mlir/IR/AsmState.h"
53 #include "mlir/IR/Dominance.h"
54 #include "mlir/IR/Iterators.h"
55 #include "mlir/IR/Operation.h"
56 #include "mlir/IR/TypeUtilities.h"
59 #include "llvm/ADT/DenseSet.h"
60 #include "llvm/ADT/SetVector.h"
61 
63 
64 // Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
65 // output.
66 #define DEBUG_TYPE "one-shot-analysis"
67 
68 using namespace mlir;
69 using namespace mlir::bufferization;
70 
71 static bool isaTensor(Type t) { return isa<TensorType>(t); }
72 
73 //===----------------------------------------------------------------------===//
74 // Bufferization-specific attribute manipulation.
75 // These are for testing and debugging only. Bufferization information is stored
76 // in OneShotBufferizationState. When run with `testAnalysisOnly`, the IR is
77 // annotated with the results of the analysis, so that they can be checked in
78 // tests.
79 //===----------------------------------------------------------------------===//
80 
81 /// Attribute marker to specify op operands that bufferize in-place.
82 constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";
83 
84 constexpr StringLiteral kOpResultAliasSetAttrName =
85  "__opresult_alias_set_attr__";
86 
87 constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";
88 
89 /// Mark whether OpOperand will be bufferized inplace.
90 static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
91  Operation *op = opOperand.getOwner();
92  SmallVector<StringRef> inPlaceVector;
93  if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) {
94  inPlaceVector = SmallVector<StringRef>(llvm::to_vector<4>(
95  cast<ArrayAttr>(attr).getAsValueRange<StringAttr>()));
96  } else {
97  inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
98  for (OpOperand &opOperand : op->getOpOperands())
99  if (isa<TensorType>(opOperand.get().getType()))
100  inPlaceVector[opOperand.getOperandNumber()] = "false";
101  }
102  inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
104  OpBuilder(op).getStrArrayAttr(inPlaceVector));
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // OneShotAnalysisState
109 //===----------------------------------------------------------------------===//
110 
114  // Set up alias sets.
115  op->walk([&](Operation *op) {
116  for (Value v : op->getResults())
117  if (isa<TensorType>(v.getType()))
119  for (Region &r : op->getRegions())
120  for (Block &b : r.getBlocks())
121  for (auto bbArg : b.getArguments())
122  if (isa<TensorType>(bbArg.getType()))
123  createAliasInfoEntry(bbArg);
124  });
125 
126  // Mark OpOperands in-place that must bufferize in-place.
127  op->walk([&](BufferizableOpInterface bufferizableOp) {
128  if (!options.isOpAllowed(bufferizableOp))
129  return WalkResult::skip();
130  for (OpOperand &opOperand : bufferizableOp->getOpOperands())
131  if (isa<TensorType>(opOperand.get().getType()))
132  if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
133  bufferizeInPlace(opOperand);
134  return WalkResult::advance();
135  });
136 }
137 
139  Value v, function_ref<void(Value)> fun) const {
140  auto leaderIt = equivalentInfo.findLeader(v);
141  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
142  ++mit) {
143  fun(*mit);
144  }
145 }
146 
148  function_ref<void(Value)> fun) const {
149  auto leaderIt = aliasInfo.findLeader(v);
150  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
151  fun(*mit);
152  }
153 }
154 
156  Value v2) const {
157  return equivalentInfo.isEquivalent(v1, v2);
158 }
159 
161  Value v2) const {
162  return aliasInfo.isEquivalent(v1, v2);
163 }
164 
166  if (inplaceBufferized.contains(&operand))
167  return;
168  inplaceBufferized.insert(&operand);
169  for (AliasingValue alias : getAliasingValues(operand))
170  aliasInfo.unionSets(alias.value, operand.get());
171  ++statNumTensorInPlace;
172 }
173 
175  assert(!inplaceBufferized.contains(&operand) &&
176  "OpOperand was already decided to bufferize inplace");
177  ++statNumTensorOutOfPlace;
178 }
179 
181  aliasInfo.insert(v);
182  equivalentInfo.insert(v);
183 }
184 
186  op->walk([&](Operation *op) {
187  // Skip unknown ops.
188  auto bufferizableOp = getOptions().dynCastBufferizableOp(op);
189  if (!bufferizableOp)
190  return WalkResult::skip();
191 
192  // Check all tensor OpResults.
193  for (OpResult opResult : op->getOpResults()) {
194  if (!isa<TensorType>(opResult.getType()))
195  continue;
196 
197  // If there is no preceding definition, the tensor contents are
198  // undefined.
199  if (findDefinitionsCached(opResult).empty())
200  for (OpOperand &use : opResult.getUses())
201  undefinedTensorUses.insert(&use);
202  }
203 
204  return WalkResult::advance();
205  });
206 }
207 
209  return undefinedTensorUses.contains(opOperand);
210 }
211 
213  return inplaceBufferized.contains(&opOperand);
214 }
215 
217  bool isWritten = false;
218  applyOnAliases(value, [&](Value val) {
219  for (OpOperand &use : val.getUses())
220  if (isInPlace(use) && bufferizesToMemoryWrite(use))
221  isWritten = true;
222  });
223  return isWritten;
224 }
225 
227  // TODO: Out-of-place bufferized value could be considered writable.
228  // Query BufferizableOpInterface to see if the BlockArgument is writable.
229  if (auto bufferizableOp =
230  getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
231  return bufferizableOp.isWritable(value, *this);
232 
233  // Not a bufferizable op: The conservative answer is "not writable".
234  return false;
235 }
236 
238  aliasInfo.unionSets(v1, v2);
239 }
240 
242  equivalentInfo.unionSets(v1, v2);
243 }
244 
246 
247 //===----------------------------------------------------------------------===//
248 // Bufferization-specific alias analysis.
249 //===----------------------------------------------------------------------===//
250 
251 /// Return true if opOperand has been decided to bufferize in-place.
252 static bool isInplaceMemoryWrite(OpOperand &opOperand,
253  const OneShotAnalysisState &state) {
254  // OpOperands that do not bufferize to a memory write do not write in-place.
255  if (!state.bufferizesToMemoryWrite(opOperand))
256  return false;
257  // Check current bufferization decisions.
258  return state.isInPlace(opOperand);
259 }
260 
261 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
262 /// properly dominates `b` and `b` is not inside `a`.
263 static bool happensBefore(Operation *a, Operation *b,
264  const DominanceInfo &domInfo) {
265  do {
266  // TODO: Instead of isProperAncestor + properlyDominates, we should use
267  // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
268  if (a->isProperAncestor(b))
269  return false;
270  if (domInfo.properlyDominates(a, b))
271  return true;
272  } while ((a = a->getParentOp()));
273  return false;
274 }
275 
276 /// Return `true` if op dominance can be used to rule out read-after-write
277 /// conflicts based on the ordering of ops. Returns `false` if op dominance
278 /// cannot be used due to region-based loops.
279 ///
280 /// Generalized op dominance can often be used to rule out potential conflicts
281 /// due to "read happens before write". E.g., the following IR is not a RaW
282 /// conflict because the read happens *before* the write.
283 ///
284 /// Example 1:
285 /// %0 = ... : tensor<?xf32> // DEF
286 /// "reading_op"(%0) : tensor<?xf32> // READ
287 /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
288 ///
289 /// This is no longer true inside loops (or repetitive regions). In such cases,
290 /// there may not be a meaningful `happensBefore` relationship because ops
291 /// could be executed multiple times. E.g.:
292 ///
293 /// Example 2:
294 /// %0 = ... : tensor<?xf32> // DEF
295 /// scf.for ... {
296 /// "reading_op"(%0) : tensor<?xf32> // READ
297 /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
298 /// ...
299 /// }
300 ///
301 /// In the above example, reading_op happens before writing_op according to
302 /// op dominance. However, both ops may happen multiple times; in
303 /// particular, the second execution of reading_op happens after the first
304 /// execution of writing_op. This is problematic because the tensor %0 they
305 /// operate on (i.e., the "definition") is defined outside of the loop.
306 ///
307 /// On a high-level, there is a potential RaW in a program if there exists a
308 /// possible program execution such that there is a sequence of DEF, followed
309 /// by WRITE, followed by READ. Each additional DEF resets the sequence.
310 ///
311 /// E.g.:
312 /// No conflict: DEF, WRITE, DEF, READ
313 /// Potential conflict: DEF, READ, WRITE, READ, WRITE
314 ///
315 /// Example 1 has no conflict: DEF, READ, WRITE
316 /// Example 2 has a potential conflict: DEF, (READ, WRITE)*
317 //
318 /// Example 3:
319 /// scf.for ... {
320 /// %0 = ... : tensor<?xf32>
321 /// "reading_op"(%0) : tensor<?xf32>
322 /// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
323 /// ...
324 /// }
325 /// This has no conflict: (DEF, READ, WRITE)*
326 ///
327 /// Example 4:
328 /// %0 = ... : tensor<?xf32>
329 /// scf.for ... {
330 /// scf.for ... { "reading_op"(%0) }
331 /// %1 = "writing_op"(%0)
332 /// }
333 /// This has a potential conflict: DEF, ((READ)*, WRITE)*
334 ///
335 /// Example 5:
336 /// %0 = ... : tensor<?xf32>
337 /// scf.for ... { %1 = "writing_op"(%0) }
338 /// scf.for ... { "reading_op"(%0) }
339 /// This has a potential conflict: DEF, WRITE*, READ*
340 ///
341 /// The following rules are used to rule out RaW conflicts via ordering of ops:
342 ///
343 /// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of
344 /// a repetitive region that encloses both READ and WRITE, we cannot rule
345 /// out a RaW conflict based on the ordering of ops.
346 /// 2. Otherwise: There are no loops that interfere with our analysis; for
347 /// analysis purposes, we can assume that there are no loops/repetitive
348 /// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
349 /// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
350 ///
352  const SetVector<Value> &definitions,
353  AnalysisState &state) {
354  const BufferizationOptions &options = state.getOptions();
355  for (Value def : definitions) {
356  Region *rRead =
357  state.getEnclosingRepetitiveRegion(uRead->getOwner(), options);
358  Region *rDef = state.getEnclosingRepetitiveRegion(def, options);
359 
360  // READ and DEF are in the same repetitive region. `happensBefore` can be
361  // used to rule out RaW conflicts due to op ordering.
362  if (rRead == rDef)
363  continue;
364 
365  // Find the enclosing repetitive region of READ that is closest to DEF but
366  // not the repetitive region of DEF itself.
367  while (true) {
368  Region *nextRegion = getNextEnclosingRepetitiveRegion(rRead, options);
369  if (nextRegion == rDef)
370  break;
371  assert(nextRegion && "expected to find another repetitive region");
372  rRead = nextRegion;
373  }
374 
375  // We cannot use op dominance if WRITE is inside the same repetitive region.
376  if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
377  return false;
378  }
379 
380  return true;
381 }
382 
383 /// Return `true` if op dominance can be used to rule out read-after-write
384 /// conflicts based on the ordering of ops. Returns `false` if op dominance
385 /// cannot be used due to block-based loops within a region.
386 ///
387 /// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
388 /// how op dominance is used during RaW conflict detection.
389 ///
390 /// On a high-level, there is a potential RaW in a program if there exists a
391 /// possible program execution such that there is a sequence of DEF, followed
392 /// by WRITE, followed by READ. Each additional DEF resets the sequence.
393 ///
394 /// Op dominance cannot be used if there is a path from block(READ) to
395 /// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
396 /// not appear on that path.
398  const SetVector<Value> &definitions,
399  AnalysisState &state) {
400  // Fast path: If READ and WRITE are in different regions, their block cannot
401  // be reachable just via unstructured control flow. (Loops due to regions are
402  // covered by `canUseOpDominanceDueToRegions`.)
403  if (uRead->getOwner()->getParentRegion() !=
404  uWrite->getOwner()->getParentRegion())
405  return true;
406 
407  Block *readBlock = uRead->getOwner()->getBlock();
408  Block *writeBlock = uWrite->getOwner()->getBlock();
409  for (Value def : definitions) {
410  Block *defBlock = def.getParentBlock();
411  if (readBlock->isReachable(writeBlock, {defBlock}) &&
412  writeBlock->isReachable(readBlock, {defBlock}))
413  return false;
414  }
415 
416  return true;
417 }
418 
419 static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
420  const SetVector<Value> &definitions,
421  AnalysisState &state) {
422  return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
423  canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
424 }
425 
426 /// Annotate IR with details about the detected RaW conflict.
427 static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
428  Value definition) {
429  static uint64_t counter = 0;
430  Operation *readingOp = uRead->getOwner();
431  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
432 
433  OpBuilder b(conflictingWritingOp->getContext());
434  std::string id = "C_" + std::to_string(counter++);
435 
436  std::string conflictingWriteAttr =
437  id +
438  "[CONFL-WRITE: " + std::to_string(uConflictingWrite->getOperandNumber()) +
439  "]";
440  conflictingWritingOp->setAttr(conflictingWriteAttr, b.getUnitAttr());
441 
442  std::string readAttr =
443  id + "[READ: " + std::to_string(uRead->getOperandNumber()) + "]";
444  readingOp->setAttr(readAttr, b.getUnitAttr());
445 
446  if (auto opResult = dyn_cast<OpResult>(definition)) {
447  std::string defAttr =
448  id + "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]";
449  opResult.getDefiningOp()->setAttr(defAttr, b.getUnitAttr());
450  } else {
451  auto bbArg = cast<BlockArgument>(definition);
452  std::string defAttr =
453  id + "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
454  bbArg.getOwner()->getParentOp()->setAttr(defAttr, b.getUnitAttr());
455  }
456 }
457 
458 /// Return 'true' if a tensor that is equivalent to `other` can be found in the
459 /// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
460 /// place along that use-def chain, the two tensors may not materialize as
461 /// equivalent buffers (but separate allocations).
462 ///
463 /// Note: This function also requires that the two tensors have equivalent
464 /// indexing. I.e., the tensor types do not change along the use-def chain,
465 /// apart from static <-> dynamic dim casts.
467  Value start, Value other) {
469  config.followEquivalentOnly = true;
470  config.alwaysIncludeLeaves = false;
471  config.followSameTypeOrCastsOnly = true;
472  return !state
473  .findValueInReverseUseDefChain(
474  start, [&](Value v) { return v == other; }, config)
475  .empty();
476 }
477 
478 /// Return "true" if `value` is originating from a subset that is equivalent to
479 /// the subset that `subsetOp` inserts into.
480 static bool matchesInsertDestination(const AnalysisState &state, Value value,
481  SubsetInsertionOpInterface subsetOp) {
482  auto matchingSubset = [&](Value val) {
483  if (auto opResult = dyn_cast<OpResult>(val))
484  if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
485  return state.areEquivalentBufferizedValues(v1, v2);
486  }))
487  return true;
488  return false;
489  };
490  // There may be multiple leaves at which the reverse SSA use-def chain lookup
491  // terminates. All of them must be equivalent subsets.
492  SetVector<Value> backwardSlice =
493  state.findValueInReverseUseDefChain(value, matchingSubset);
494  return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
495 }
496 
497 /// Return "true" if the given "read" and potentially conflicting "write" are
498 /// not conflicting due to their subset relationship. The comments in this
499 /// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
500 /// pairs, but apply to any subset ops that implement the
501 /// `SubsetInsertionOpInterface`.
503  OpOperand *uConflictingWrite,
504  const AnalysisState &state) {
505  Operation *readingOp = uRead->getOwner();
506  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
507 
508  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
509  // uRead is an InsertSliceOp...
510  if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
511  // As an example, consider the following IR.
512  //
513  // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
514  // %1 = linalg.fill %cst, %0 {inplace= [true] }
515  // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
516  // {inplace= [true] }
517 
518  if (uRead == &subsetOp.getDestinationOperand() &&
519  matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
520  // Case 1: The main insight is that InsertSliceOp reads only part of
521  // the destination tensor. The overwritten area is not read. If
522  // uConflictingWrite writes into exactly the memory location that is
523  // being read by uRead, this is not a conflict.
524  //
525  // In the above example:
526  // uRead = OpOperand 1 (%t) of tensor.insert_slice
527  // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
528  //
529  // The read of %t does not conflict with the write of the FillOp
530  // (same aliases!) because the area that the FillOp operates on is
531  // exactly the one that is *not* read via %t.
532  return true;
533 
534  if (uRead == &subsetOp.getSourceOperand() &&
535  uConflictingWrite == &subsetOp.getDestinationOperand() &&
536  matchesInsertDestination(state, uRead->get(), subsetOp))
537  // Case 2: The read of the source tensor and the write to the dest
538  // tensor via an InsertSliceOp is not a conflict if the read is
539  // reading exactly that part of an equivalent tensor that the
540  // InsertSliceOp is writing.
541  //
542  // In the above example:
543  // uRead = OpOperand 0 (%1) of tensor.insert_slice
544  // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
545  return true;
546  }
547 
548  // If uConflictingWrite is an InsertSliceOp...
549  if (auto subsetOp =
550  dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
551  // As an example, consider the following IR.
552  //
553  // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
554  // %1 = linalg.fill %cst, %0 {inplace= [true] }
555  // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
556  // {inplace= [true] }
557  // %3 = vector.transfer_read %1, %cst
558  //
559  // In the above example:
560  // uRead = OpOperand 0 (%1) of vector.transfer_read
561  // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
562  // definition = %1
563  //
564  // This is not a conflict because the InsertSliceOp overwrites the
565  // memory segment of %1 with the exact same data. (Effectively, there
566  // is no memory write here.)
567  if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
568  state.areEquivalentBufferizedValues(
569  uRead->get(), subsetOp.getSourceOperand().get()) &&
570  matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
571  subsetOp))
572  return true;
573 
574  return false;
575 }
576 
577 /// Given sets of uses and writes, return true if there is a RaW conflict under
578 /// the assumption that all given reads/writes alias the same buffer and that
579 /// all given writes bufferize inplace.
580 ///
581 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
582 /// the result of a definition W1. But because of bufferization decisions, R
583 /// actually reads another definition W2.
584 static bool
586  const DenseSet<OpOperand *> &usesWrite,
587  const DominanceInfo &domInfo,
588  OneShotAnalysisState &state) {
589  const BufferizationOptions &options = state.getOptions();
590 
591  // Before going through the main RaW analysis, find cases where a buffer must
592  // be privatized due to parallelism. If the result of a write is never read,
593  // privatization is not necessary (and large parts of the IR are likely dead).
594  if (options.checkParallelRegions && !usesRead.empty()) {
595  for (OpOperand *uConflictingWrite : usesWrite) {
596  // Find the allocation point or last write (definition) of the buffer.
597  // Note: In contrast to `findDefinitions`, this also returns results of
598  // ops that do not bufferize to memory write when no other definition
599 // could be found. E.g., "bufferization.alloc_tensor" would be included,
600 // even though that op just bufferizes to an allocation and does not
601 // define the contents of the buffer.
602  SetVector<Value> definitionsOrLeaves =
603  state.findValueInReverseUseDefChain(
604  uConflictingWrite->get(),
605  [&](Value v) { return state.bufferizesToMemoryWrite(v); });
606  assert(!definitionsOrLeaves.empty() &&
607  "expected at least one definition or leaf");
608 
609  // The writing op must bufferize out-of-place if the definition is in a
610  // different parallel region than this write.
611  for (Value def : definitionsOrLeaves) {
612  if (getParallelRegion(def.getParentRegion(), options) !=
613  getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
614  options)) {
615  LLVM_DEBUG(
616  llvm::dbgs()
617  << "\n- bufferizes out-of-place due to parallel region:\n");
618  LLVM_DEBUG(llvm::dbgs()
619  << " unConflictingWrite = operand "
620  << uConflictingWrite->getOperandNumber() << " of "
621  << *uConflictingWrite->getOwner() << "\n");
622  return true;
623  }
624  }
625  }
626  }
627 
628  for (OpOperand *uRead : usesRead) {
629  Operation *readingOp = uRead->getOwner();
630  LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
631  LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
632  << " of " << *readingOp << "\n");
633 
634  // Find the definition of uRead by following the SSA use-def chain.
635  // E.g.:
636  //
637 // %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
638 // %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
639 // %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
640  //
641  // In the above example, if uRead is the OpOperand of reading_op, the
642  // definition is %0. Note that operations that create an alias but do not
643  // bufferize to a memory write (such as ExtractSliceOp) are skipped.
644  const SetVector<Value> &definitions =
645  state.findDefinitionsCached(uRead->get());
646  if (definitions.empty()) {
647  // Fast path: No conflict if there are no definitions.
648  LLVM_DEBUG(llvm::dbgs()
649  << " no conflict: read value has no definitions\n");
650  continue;
651  }
652 
653  // Look for conflicting memory writes. Potential conflicts are writes to an
654  // alias that have been decided to bufferize inplace.
655  for (OpOperand *uConflictingWrite : usesWrite) {
656  LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
657  << uConflictingWrite->getOperandNumber() << " of "
658  << *uConflictingWrite->getOwner() << "\n");
659 
660  // Check if op dominance can be used to rule out read-after-write
661  // conflicts.
662  bool useDominance =
663  canUseOpDominance(uRead, uConflictingWrite, definitions, state);
664  LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
665 
666  // Throughout this loop, check for multiple requirements that have to be
667  // met for uConflictingWrite to be an actual conflict.
668  Operation *conflictingWritingOp = uConflictingWrite->getOwner();
669 
670  // Inside of repetitive regions, ops may be executed multiple times and op
671  // dominance cannot be used to rule out conflicts.
672  if (useDominance) {
673  // No conflict if the readingOp dominates conflictingWritingOp, i.e.,
674  // the write is not visible when reading.
675  //
676  // Note: If ops are executed multiple times (e.g., because they are
677  // inside a loop), there may be no meaningful `happensBefore`
678  // relationship.
679  if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
680  LLVM_DEBUG(llvm::dbgs()
681  << " no conflict: read happens before write\n");
682  continue;
683  }
684 
685  // No conflict if the reading use equals the use of the conflicting
686  // write. A use cannot conflict with itself.
687  //
688  // Note: Just being the same op is not enough. It has to be the same
689  // use.
690  // Note: If the op is executed multiple times (e.g., because it is
691  // inside a loop), it may be conflicting with itself.
692  if (uConflictingWrite == uRead) {
693  LLVM_DEBUG(llvm::dbgs()
694  << " no conflict: read and write are same use\n");
695  continue;
696  }
697 
698  // Ops are not conflicting if they are in mutually exclusive regions.
699  //
700  // Note: If ops are executed multiple times (e.g., because they are
701  // inside a loop), mutually exclusive regions may be executed
702  // multiple times.
703  if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
704  LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
705  "mutually exclusive regions\n");
706  continue;
707  }
708 
709  // Two equivalent operands of the same op are not conflicting if the op
710  // bufferizes to element-wise access. I.e., all loads at a position
711  // happen before all stores to the same position.
712  if (conflictingWritingOp == readingOp) {
713  if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
714  if (bufferizableOp.bufferizesToElementwiseAccess(
715  state, {uRead, uConflictingWrite})) {
717  state, uRead->get(), uConflictingWrite->get()) ||
719  state, uConflictingWrite->get(), uRead->get())) {
720  LLVM_DEBUG(
721  llvm::dbgs()
722  << " no conflict: op bufferizes to element-wise access\n");
723  continue;
724  }
725  }
726  }
727  }
728  }
729 
730  // No conflict if the operands are non-conflicting subsets.
731  if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
732  LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
733  continue;
734  }
735 
736  // No conflict if the op interface says so.
737  if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
738  if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
739  LLVM_DEBUG(llvm::dbgs()
740  << " no conflict: op interace of reading op says 'no'\n");
741  continue;
742  }
743  }
744 
745  if (conflictingWritingOp != readingOp) {
746  if (auto bufferizableOp =
747  options.dynCastBufferizableOp(conflictingWritingOp)) {
748  if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
749  state)) {
750  LLVM_DEBUG(
751  llvm::dbgs()
752  << " no conflict: op interace of writing op says 'no'\n");
753  continue;
754  }
755  }
756  }
757 
758  // Check all possible definitions.
759  for (Value definition : definitions) {
760  LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
761 
762  // No conflict if the conflicting write happens before the definition.
763  if (Operation *defOp = definition.getDefiningOp()) {
764  if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
765  // conflictingWritingOp happens before defOp. No conflict.
766  LLVM_DEBUG(llvm::dbgs()
767  << " no conflict: write happens before definition\n");
768  continue;
769  }
770  // No conflict if conflictingWritingOp is contained in defOp.
771  if (defOp->isProperAncestor(conflictingWritingOp)) {
772  LLVM_DEBUG(
773  llvm::dbgs()
774  << " no conflict: write is contained in definition\n");
775  continue;
776  }
777  } else {
778  auto bbArg = cast<BlockArgument>(definition);
779  Block *block = bbArg.getOwner();
780  if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
781  LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
782  "and write happens outside of block\n");
783  // conflictingWritingOp happens outside of the block. No
784  // conflict.
785  continue;
786  }
787  }
788 
789  // No conflict if the conflicting write and the definition are the same
790  // use.
791  AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
792  if (aliases.getNumAliases() == 1 &&
793  aliases.getAliases()[0].value == definition) {
794  LLVM_DEBUG(llvm::dbgs()
795  << " no conflict: definition and write are same\n");
796  continue;
797  }
798 
799  // All requirements are met. Conflict found!
800 
801  if (options.printConflicts)
802  annotateConflict(uRead, uConflictingWrite, definition);
803  LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
804  return true;
805  }
806  }
807  }
808 
809  return false;
810 }
811 
812 // Helper function to iterate on aliases of `root` and capture the writes.
814  const OneShotAnalysisState &state) {
815  state.applyOnAliases(root, [&](Value alias) {
816  for (auto &use : alias.getUses())
817  // Inplace write to a value that aliases root.
818  if (isInplaceMemoryWrite(use, state))
819  res.insert(&use);
820  });
821 }
822 
823 // Helper function to iterate on aliases of `root` and capture the reads.
// Matching uses are accumulated into `res`. A use is recorded when it either
// bufferizes to a memory read, or when it is neither a memory read nor a
// memory write but at least one of its aliasing values is read
// (`state.isValueRead`).
// NOTE(review): original line 824 (the start of the signature) is absent from
// this listing. The call sites further down refer to this helper as
// `getAliasingReads` -- confirm against the upstream file.
825  const OneShotAnalysisState &state) {
826  state.applyOnAliases(root, [&](Value alias) {
827  for (auto &use : alias.getUses()) {
828  // Read of a value that aliases root.
829  if (state.bufferizesToMemoryRead(use)) {
830  res.insert(&use);
831  continue;
832  }
833 
834  // Read of a dependent value in the SSA use-def chain. E.g.:
835  //
836  // %0 = ...
837  // %1 = tensor.extract_slice %0 {not_analyzed_yet}
838  // "read"(%1)
839  //
840  // In the above example, getAliasingReads(%0) includes the first OpOperand
841  // of the tensor.extract_slice op. The extract_slice itself does not read
842  // but its aliasing result is eventually fed into an op that does.
843  //
844  // Note: This is considered a "read" only if the use does not bufferize to
845  // a memory write. (We already ruled out memory reads. In case of a memory
846  // write, the buffer would be entirely overwritten; in the above example
847  // there would then be no flow of data from the extract_slice operand to
848  // its result's uses.)
849  if (!state.bufferizesToMemoryWrite(use)) {
850  AliasingValueList aliases = state.getAliasingValues(use);
851  if (llvm::any_of(aliases, [&](AliasingValue a) {
852  return state.isValueRead(a.value);
853  }))
854  res.insert(&use);
855  }
856  }
857  });
858 }
859 
860 /// Return true if bufferizing `operand` inplace would create a conflict. A read
861 /// R and a write W of the same alias set is a conflict if inplace bufferization
862 /// of W changes the value read by R to a value different from the one that
863 /// would be expected by tracing back R's origin through SSA use-def chains.
864 /// A conflict can only be introduced by a new alias and/or an inplace
865 /// bufferization decision.
866 ///
867 /// Example:
868 /// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
869 /// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
870 /// %e = tensor.extract_slice %1
871 /// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
872 /// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
873 ///
874 /// In the above example, the two TransferWriteOps have already been decided to
875 /// bufferize inplace. Bufferizing the ExtractSliceOp inplace would create a
876 /// conflict because:
877 /// * According to SSA use-def chains, we expect to read the result of %1.
878 /// * However, adding an alias {%0, %t} would mean that the second
879 /// TransferWriteOp overwrites the result of the first one. Therefore, the
880 /// TransferReadOp would no longer be reading the result of %1.
881 ///
882 /// If `checkConsistencyOnly` is true, this function checks if there is a
883 /// read-after-write conflict without bufferizing `operand` inplace. This would
884 /// indicate a problem with the current inplace bufferization decisions.
885 ///
886 /// Note: If `checkConsistencyOnly`, this function may be called with a null
887 /// OpResult. In that case, only the consistency of bufferization decisions
888 /// involving aliases of the given OpOperand are checked.
// NOTE(review): the function takes an `OpOperand &` and no OpResult, so the
// "null OpResult" remark above looks outdated -- confirm.
// NOTE(review): original line 889 (return type and function name) is absent
// from this listing; callers use the name
// `wouldCreateReadAfterWriteInterference`.
890  OpOperand &operand, const DominanceInfo &domInfo,
891  OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
892  // Collect reads and writes of all aliases of OpOperand and OpResult.
893  DenseSet<OpOperand *> usesRead, usesWrite;
894  getAliasingReads(usesRead, operand.get(), state);
895  getAliasingInplaceWrites(usesWrite, operand.get(), state);
896  for (AliasingValue alias : state.getAliasingValues(operand)) {
897  getAliasingReads(usesRead, alias.value, state);
898  getAliasingInplaceWrites(usesWrite, alias.value, state);
899  }
  // Outside of consistency-check mode, `operand` itself is treated as a write
  // if it bufferizes to a memory write, i.e. the inplace decision under test
  // is assumed to have been made.
900  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
901  usesWrite.insert(&operand);
902 
903  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
904 }
905 
906 /// Annotate IR with details about the detected non-writability conflict.
907 static void annotateNonWritableTensor(Value value) {
908  static int64_t counter = 0;
909  OpBuilder b(value.getContext());
910  std::string id = "W_" + std::to_string(counter++);
911  if (auto opResult = dyn_cast<OpResult>(value)) {
912  std::string attr = id + "[NOT-WRITABLE: result " +
913  std::to_string(opResult.getResultNumber()) + "]";
914  opResult.getDefiningOp()->setAttr(attr, b.getUnitAttr());
915  } else {
916  auto bbArg = cast<BlockArgument>(value);
917  std::string attr = id + "[NOT-WRITABLE: bbArg " +
918  std::to_string(bbArg.getArgNumber()) + "]";
919  bbArg.getOwner()->getParentOp()->setAttr(attr, b.getUnitAttr());
920  }
921 }
922 
923 /// Return true if bufferizing `operand` inplace would create a write to a
924 /// non-writable buffer.
/// With `checkConsistencyOnly` set, `operand` itself is not counted as a
/// write; only inplace writes already recorded on its aliases are considered.
// NOTE(review): original line 926 (function name and first parameter) is
// absent from this listing; callers use the name
// `wouldCreateWriteToNonWritableBuffer`.
925 static bool
927  OneShotAnalysisState &state,
928  bool checkConsistencyOnly = false) {
929  bool foundWrite =
930  !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
931 
  // `operand` is not (counted as) a write: fall back to looking for inplace
  // writes on any alias of the operand or of its aliasing values.
932  if (!foundWrite) {
933  // Collect writes of all aliases of OpOperand and OpResult.
934  DenseSet<OpOperand *> usesWrite;
935  getAliasingInplaceWrites(usesWrite, operand.get(), state);
936  for (AliasingValue alias : state.getAliasingValues(operand))
937  getAliasingInplaceWrites(usesWrite, alias.value, state);
938  foundWrite = !usesWrite.empty();
939  }
940 
  // Without any write there is nothing that could hit a read-only buffer.
941  if (!foundWrite)
942  return false;
943 
944  // Look for a read-only tensor among all aliases.
945  bool foundReadOnly = false;
946  auto checkReadOnly = [&](Value v) {
947  if (!state.isWritable(v)) {
948  foundReadOnly = true;
  // NOTE(review): original line 950 (the statement guarded by
  // `printConflicts`) is absent from this listing; presumably it calls
  // `annotateNonWritableTensor(v)` -- confirm.
949  if (state.getOptions().printConflicts)
951  }
952  };
953  state.applyOnAliases(operand.get(), checkReadOnly);
954  for (AliasingValue alias : state.getAliasingValues(operand))
955  state.applyOnAliases(alias.value, checkReadOnly);
956  if (foundReadOnly) {
957  LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
958  return true;
959  }
960 
961  return false;
962 }
963 
964 //===----------------------------------------------------------------------===//
965 // Bufferization analyses.
966 //===----------------------------------------------------------------------===//
967 
968 // Find the values that define the contents of the given value.
// The result of `findDefinitions(value)` is stored in `cachedDefinitions` on
// the first query for `value`; later queries return the stored entry.
// NOTE(review): original line 969 (the return type) is absent from this
// listing.
970 OneShotAnalysisState::findDefinitionsCached(Value value) {
971  if (!cachedDefinitions.count(value))
972  cachedDefinitions[value] = findDefinitions(value);
973  return cachedDefinitions[value];
974 }
975 
// Drop all entries of `cachedDefinitions`.
// NOTE(review): the signature (original lines 976-977) is absent from this
// listing; presumably this is the body of `OneShotAnalysisState::resetCache`
// -- confirm.
978  cachedDefinitions.clear();
979 }
980 
981 /// Determine if `operand` can be bufferized in-place.
/// The operand is marked out-of-place in `state` if an inplace decision would
/// write to a non-writable buffer or create a read-after-write conflict;
/// otherwise it is marked in-place. This function always returns success().
// NOTE(review): original line 983 (function name and leading parameters) is
// absent from this listing; the caller below uses the name
// `bufferizableInPlaceAnalysisImpl` with arguments (operand, state, domInfo).
982 static LogicalResult
984  const DominanceInfo &domInfo) {
985  LLVM_DEBUG(
986  llvm::dbgs() << "//===-------------------------------------------===//\n"
987  << "Analyzing operand #" << operand.getOperandNumber()
988  << " of " << *operand.getOwner() << "\n");
989 
  // `||` short-circuits: the RaW check only runs if the writability check
  // found no problem.
990  bool foundInterference =
991  wouldCreateWriteToNonWritableBuffer(operand, state) ||
992  wouldCreateReadAfterWriteInterference(operand, domInfo, state);
993 
994  if (foundInterference)
995  state.bufferizeOutOfPlace(operand);
996  else
997  state.bufferizeInPlace(operand);
998 
999  LLVM_DEBUG(llvm::dbgs()
1000  << "//===-------------------------------------------===//\n");
1001  return success();
1002 }
1003 
// Run the in-place analysis on every tensor-typed operand of `op`, in operand
// order. Returns failure as soon as the analysis of one operand fails.
// NOTE(review): original line 1005 (function name and first parameter) is
// absent from this listing; the body uses `*this` as the analysis state and
// callers use the name `analyzeSingleOp`.
1004 LogicalResult
1006  const DominanceInfo &domInfo) {
1007  for (OpOperand &opOperand : op->getOpOperands())
1008  if (isa<TensorType>(opOperand.get().getType()))
1009  if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
1010  return failure();
1011  return success();
1012 }
1013 
1014 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
/// Ops for which `dynCastBufferizableOp` returns null and non-tensor results
/// are skipped.
// NOTE(review): original line 1015 (function name and first parameter) is
// absent from this listing; callers use the name `equivalenceAnalysis`.
1016  OneShotAnalysisState &state) {
1017  for (Operation *op : ops) {
1018  if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
1019  for (OpResult opResult : op->getOpResults()) {
1020  if (!isa<TensorType>(opResult.getType()))
1021  continue;
1022  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
1023  if (aliases.getNumAliases() == 0)
1024  // Nothing to do if there are no aliasing OpOperands.
1025  continue;
1026 
  // Reference value: every aliasing operand is compared against the first
  // one to decide whether all of them are equivalent.
1027  Value firstOperand = aliases.begin()->opOperand->get();
1028  bool allEquivalent = true;
1029  for (AliasingOpOperand alias : aliases) {
1030  bool isEquiv = alias.relation == BufferRelation::Equivalent;
1031  bool isInPlace = state.isInPlace(*alias.opOperand);
1032  Value operand = alias.opOperand->get();
1033  if (isEquiv && isInPlace && alias.isDefinite) {
1034  // Found a definite, equivalent alias. Merge equivalence sets.
1035  // There can only be one definite alias, so we can stop here.
1036  state.unionEquivalenceClasses(opResult, operand);
  // Clearing the flag prevents the "all maybe-aliases equivalent" union
  // after the loop from running as well.
1037  allEquivalent = false;
1038  break;
1039  }
1040  if (!isEquiv || !isInPlace)
1041  allEquivalent = false;
1042  if (!state.areEquivalentBufferizedValues(operand, firstOperand))
1043  allEquivalent = false;
1044  }
1045 
1046  // If all "maybe" aliases are equivalent and the OpResult is not a new
1047  // allocation, it is a definite, equivalent alias. E.g.:
1048  //
1049  // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
1050  // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
1051  // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
1052  // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
1053  //
1054  // If %t0 and %t1 are equivalent, it is safe to union the equivalence
1055  // classes of %r, %t0 and %t1.
1056  if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
1057  state.unionEquivalenceClasses(opResult, firstOperand);
1058  }
1059  }
1060  }
1061 }
1062 
1063 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
1064 /// in `op`.
/// Ops are collected into `ops` in post-order and then handed to the
/// list-based `equivalenceAnalysis`. Ops for which no result type satisfies
/// `isaTensor` are left out.
// NOTE(review): original lines 1065 (signature) and 1067 (presumably the
// declaration of `ops`) are absent from this listing -- confirm against the
// upstream file.
1066  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
1068  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
1069  // No tensors => no buffers.
1070  if (none_of(op->getResultTypes(), isaTensor))
1071  return;
1072  ops.push_back(op);
1073  });
1074 
1075  equivalenceAnalysis(ops, state);
1076 }
1077 
1078 /// "Bottom-up from terminators" heuristic.
/// Builds the analysis order: first the region branch terminators together
/// with the ops reached by walking backwards from their tensor operands
/// through aliasing OpOperands, then the ops with tensor semantics that were
/// not reached that way.
// NOTE(review): original lines 1079-1080 (start of the signature) and 1121
// (the call that opens the final callback, presumably another `op->walk`) are
// absent from this listing -- confirm against the upstream file.
1081  const OneShotAnalysisState &state) {
1082  SetVector<Operation *> traversedOps;
1083 
1084  // Find region terminators.
1085  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
  // `insert` returns false if the terminator was already recorded.
1086  if (!traversedOps.insert(term))
1087  return;
1088  // Follow the reverse SSA use-def chain from each yielded value as long as
1089  // we stay within the same region.
1090  SmallVector<OpResult> worklist;
  // Seed the worklist with tensor operands of the terminator that are op
  // results; block arguments are not followed.
1091  for (Value v : term->getOperands()) {
1092  if (!isa<TensorType>(v.getType()))
1093  continue;
1094  auto opResult = dyn_cast<OpResult>(v);
1095  if (!opResult)
1096  continue;
1097  worklist.push_back(opResult);
1098  }
1099  while (!worklist.empty()) {
1100  OpResult opResult = worklist.pop_back_val();
1101  Operation *defOp = opResult.getDefiningOp();
  // NOTE(review): `defOp` is recorded in `traversedOps` before the region
  // check below, so a defining op outside the terminator's region is still
  // added to the traversed set -- confirm this is intended.
1102  if (!traversedOps.insert(defOp))
1103  continue;
1104  if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
1105  continue;
1106  AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
1107  for (auto alias : aliases) {
1108  Value v = alias.opOperand->get();
1109  if (!isa<TensorType>(v.getType()))
1110  continue;
1111  auto opResult = dyn_cast<OpResult>(v);
1112  if (!opResult)
1113  continue;
1114  worklist.push_back(opResult);
1115  }
1116  }
1117  });
1118 
1119  // Analyze traversed ops, then all remaining ops.
1120  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
1122  if (!traversedOps.contains(op) && hasTensorSemantics(op))
1123  result.push_back(op);
1124  });
1125  return result;
1126 }
1127 
// Analyze `op` and its nested ops: compute an analysis order according to the
// selected heuristic, run `analyzeSingleOp` on each op in that order (stopping
// at the first failure) and finally run the equivalence analysis on `op`.
// NOTE(review): several original lines are absent from this listing: 1128
// (signature start), 1130-1131 (presumably the lookup of `heuristic`), 1135
// (the enumerator compared against) and the `case` labels 1145, 1150 and 1154
// -- confirm all of them against the upstream file.
1129  const DominanceInfo &domInfo) {
1132 
1133  SmallVector<Operation *> orderedOps;
1134  if (heuristic ==
1136  orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
1137  } else {
  // All other heuristics start from the ops with tensor semantics in the
  // order of the default `walk` and then reorder that list.
1138  op->walk([&](Operation *op) {
1139  // No tensors => no buffers.
1140  if (!hasTensorSemantics(op))
1141  return;
1142  orderedOps.push_back(op);
1143  });
1144  switch (heuristic) {
1146  // Default: Walk ops in reverse for better interference analysis.
1147  std::reverse(orderedOps.begin(), orderedOps.end());
1148  break;
1149  }
1151  // Ops are already sorted top-down in `orderedOps`.
1152  break;
1153  }
1155  assert(getOptions().analysisFuzzerSeed &&
1156  "expected that fuzzer seed it set");
1157  // This is a fuzzer. For testing purposes only. Randomize the order in
1158  // which operations are analyzed. The bufferization quality is likely
1159  // worse, but we want to make sure that no assertions are triggered
1160  // anywhere.
1161  std::mt19937 g(getOptions().analysisFuzzerSeed);
1162  llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
1163  break;
1164  }
1165  default: {
1166  llvm_unreachable("unsupported heuristic");
1167  }
1168  }
1169  }
1170 
1171  // Analyze ops in the computed order.
1172  for (Operation *op : orderedOps)
1173  if (failed(analyzeSingleOp(op, domInfo)))
1174  return failure();
1175 
1176  equivalenceAnalysis(op, *this);
1177  return success();
1178 }
1179 
1180 /// Perform various checks on the input IR to see if it contains IR constructs
1181 /// that are unsupported by One-Shot Bufferize.
/// An op error is emitted and failure is returned for: a region with multiple
/// blocks on an op that does not support unstructured control flow; a
/// to_tensor op without `restrict` that has uses; a RaW conflict that already
/// exists before any decision is made; an in-place operand that would write
/// to a read-only buffer. Ops rejected by `options.isOpAllowed` are skipped.
// NOTE(review): original lines 1183 (function name and leading parameters),
// 1230 and 1244 (presumably the starts of the calls to
// `wouldCreateReadAfterWriteInterference` and
// `wouldCreateWriteToNonWritableBuffer`) are absent from this listing --
// confirm against the upstream file.
1182 static LogicalResult
1184  OneShotAnalysisState &state) {
1185  const BufferizationOptions &options = state.getOptions();
1186 
1187  // Note: This walk cannot be combined with the one below because interface
1188  // methods of invalid/unsupported ops may be called during the second walk.
1189  // (On ops different from `op`.)
1190  WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
1191  // Skip ops that are not in the filter.
1192  if (!options.isOpAllowed(op.getOperation()))
1193  return WalkResult::advance();
1194 
1195  // Check for unsupported unstructured control flow.
1196  if (!op.supportsUnstructuredControlFlow()) {
1197  for (Region &r : op->getRegions()) {
1198  if (r.getBlocks().size() > 1) {
1199  op->emitOpError("op or BufferizableOpInterface implementation does "
1200  "not support unstructured control flow, but at least "
1201  "one region has multiple blocks");
1202  return WalkResult::interrupt();
1203  }
1204  }
1205  }
1206 
1207  return WalkResult::advance();
1208  });
1209  if (walkResult.wasInterrupted())
1210  return failure();
1211 
1212  walkResult = op->walk([&](BufferizableOpInterface op) {
1213  // Skip ops that are not in the filter.
1214  if (!options.isOpAllowed(op.getOperation()))
1215  return WalkResult::advance();
1216 
1217  // Input IR may not contain any ToTensorOps without the "restrict"
1218  // attribute. Such tensors may alias any other tensor, which is currently
1219  // not handled in the analysis.
1220  if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
1221  if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
1222  op->emitOpError("to_tensor ops without `restrict` are not supported by "
1223  "One-Shot Analysis");
1224  return WalkResult::interrupt();
1225  }
1226  }
1227 
  // Both checks below run in consistency-only mode, i.e. they examine the
  // decisions already recorded in `state` without assuming a new one.
1228  for (OpOperand &opOperand : op->getOpOperands()) {
1229  if (isa<TensorType>(opOperand.get().getType())) {
1231  opOperand, domInfo, state,
1232  /*checkConsistencyOnly=*/true)) {
1233  // This error can happen if certain "mustBufferizeInPlace" interface
1234  // methods are implemented incorrectly, such that the IR already has
1235  // a RaW conflict before making any bufferization decisions. It can
1236  // also happen if the bufferization.materialize_in_destination is used
1237  // in such a way that a RaW conflict is not avoidable.
1238  op->emitOpError("not bufferizable under the given constraints: "
1239  "cannot avoid RaW conflict");
1240  return WalkResult::interrupt();
1241  }
1242 
1243  if (state.isInPlace(opOperand) &&
1245  opOperand, state, /*checkConsistencyOnly=*/true)) {
1246  op->emitOpError("not bufferizable under the given constraints: would "
1247  "write to read-only buffer");
1248  return WalkResult::interrupt();
1249  }
1250  }
1251  }
1252 
1253  return WalkResult::advance();
1254  });
1255 
1256  return success(!walkResult.wasInterrupted());
1257 }
1258 
1259 /// Annotate the IR with the result of the analysis. For testing/debugging only.
/// For every tensor-typed operand of every op visited by the walk,
/// `setInPlaceOpOperand` is called with the in-place decision recorded in
/// `state`.
// NOTE(review): original line 1261 (function name and first parameter) is
// absent from this listing; presumably this is
// `annotateOpsWithBufferizationMarkers` -- confirm.
1260 static void
1262  const OneShotAnalysisState &state) {
1263  // Add __inplace_operands_attr__.
1264  op->walk([&](Operation *op) {
1265  for (OpOperand &opOperand : op->getOpOperands())
1266  if (isa<TensorType>(opOperand.get().getType()))
1267  setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
1268  });
1269 }
1270 
// Attach the alias sets computed by the analysis to the IR as attributes.
// Each visited op with tensor results gets `kOpResultAliasSetAttrName`: one
// array of alias names per tensor result. Each visited op with at least one
// tensor block argument gets `kBbArgAliasSetAttrName`: arrays nested as
// region -> block -> tensor block argument.
// NOTE(review): original line 1271 (function name and first parameter) is
// absent from this listing; the caller uses the name
// `annotateOpsWithAliasSets`.
1272  const OneShotAnalysisState &state) {
1273  AsmState asmState(op);
1274  Builder b(op->getContext());
1275  // Helper function to build an array attribute of aliasing SSA value strings.
1276  auto buildAliasesArray = [&](Value v) {
1277  SmallVector<Attribute> aliases;
1278  state.applyOnAliases(v, [&](Value alias) {
1279  std::string buffer;
1280  llvm::raw_string_ostream stream(buffer);
1281  alias.printAsOperand(stream, asmState);
1282  aliases.push_back(b.getStringAttr(buffer));
1283  });
1284  return b.getArrayAttr(aliases);
1285  };
1286 
1287  op->walk([&](Operation *op) {
1288  // Build alias set array for every OpResult.
1289  SmallVector<Attribute> opResultAliasSets;
1290  for (OpResult opResult : op->getOpResults()) {
1291  if (llvm::isa<TensorType>(opResult.getType())) {
1292  opResultAliasSets.push_back(buildAliasesArray(opResult));
1293  }
1294  }
1295  if (!opResultAliasSets.empty())
1296  op->setAttr(kOpResultAliasSetAttrName, b.getArrayAttr(opResultAliasSets));
1297 
1298  // Build alias set array for every BlockArgument.
1299  SmallVector<Attribute> regionAliasSets;
1300  bool hasTensorBbArg = false;
1301  for (Region &r : op->getRegions()) {
1302  SmallVector<Attribute> blockAliasSets;
1303  for (Block &block : r.getBlocks()) {
1304  SmallVector<Attribute> bbArgAliasSets;
1305  for (BlockArgument bbArg : block.getArguments()) {
1306  if (llvm::isa<TensorType>(bbArg.getType())) {
1307  bbArgAliasSets.push_back(buildAliasesArray(bbArg));
1308  hasTensorBbArg = true;
1309  }
1310  }
  // An entry is added for every block and every region, including those
  // without tensor block arguments (their arrays are empty).
1311  blockAliasSets.push_back(b.getArrayAttr(bbArgAliasSets));
1312  }
1313  regionAliasSets.push_back(b.getArrayAttr(blockAliasSets));
1314  }
1315  if (hasTensorBbArg)
1316  op->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(regionAliasSets));
1317  });
1318 }
1319 
// Analyze `op` and its nested ops: check the pre-bufferization assumptions,
// run the in-place analysis, optionally record statistics, let every
// bufferizable op verify the analysis result and optionally annotate the IR.
// Returns failure if a precondition check, the analysis itself or any
// `verifyAnalysis` call fails.
// NOTE(review): original lines 1320 (return type, function name and first
// parameter) and 1354 (the statement guarded by `testAnalysisOnly`,
// presumably a call to `annotateOpsWithBufferizationMarkers`) are absent from
// this listing -- confirm against the upstream file.
1321  OneShotAnalysisState &state,
1322  BufferizationStatistics *statistics) {
1323  DominanceInfo domInfo(op);
1324  const OneShotBufferizationOptions &options = state.getOptions();
1325 
1326  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
1327  return failure();
1328 
1329  // If the analysis fails, just return.
1330  if (failed(state.analyzeOp(op, domInfo)))
1331  return failure();
1332 
  // `statistics` is optional; it is only filled in when non-null.
1333  if (statistics) {
1334  statistics->numTensorInPlace = state.getStatNumTensorInPlace();
1335  statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
1336  }
1337 
1338  bool failedAnalysis = false;
1339 
1340  // Gather some extra analysis data.
1341  state.gatherUndefinedTensorUses(op);
1342 
1343  // Analysis verification: After setting up alias/equivalence sets, each op
1344  // can check for expected invariants/limitations and fail the analysis if
1345  // necessary.
  // A failed verification does not stop the walk; failures are accumulated
  // and reported through the return value.
1346  op->walk([&](Operation *op) {
1347  if (BufferizableOpInterface bufferizableOp =
1348  options.dynCastBufferizableOp(op))
1349  failedAnalysis |= failed(bufferizableOp.verifyAnalysis(state));
1350  });
1351 
1352  // Annotate operations if we only want to report the analysis.
1353  if (options.testAnalysisOnly)
1355  if (options.dumpAliasSets)
1356  annotateOpsWithAliasSets(op, state);
1357 
1358  return success(!failedAnalysis);
1359 }
1360 
// Analysis + bufferization entry point. Unless `copyBeforeWrite` is set,
// tensor copies are inserted first via `insertTensorCopies`; with
// `testAnalysisOnly` the function returns right after that step. Otherwise the
// result of `bufferizeOp` is returned.
// NOTE(review): original lines 1362-1363 (function name and leading
// parameters) are absent from this listing; presumably this is
// `runOneShotBufferize(op, options, statistics)` -- confirm.
1361 LogicalResult
1364  BufferizationStatistics *statistics) {
1365  // copy-before-write deactivates the analysis. It cannot be used together with
1366  // test-analysis-only.
1367  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
1368  "invalid combination of bufferization flags");
1369 
1370  if (options.copyBeforeWrite) {
1371  // Copy buffer before each write. No analysis is needed.
1372  } else {
1373  // Run One-Shot Analysis and insert buffer copies (on the tensor level)
1374  // only where needed. This is the default and much more efficient than
1375  // copy-before-write.
1376  if (failed(insertTensorCopies(op, options, statistics)))
1377  return failure();
1378 
1379  // If test-analysis-only is set, the IR was annotated with RaW conflict
1380  // markers (attributes) during One-Shot Analysis.
1381  if (options.testAnalysisOnly)
1382  return success();
1383  }
1384 
1385  // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
1386  // a new buffer copy is allocated every time a buffer is written to.
1387  return bufferizeOp(op, options, statistics);
1388 }
static bool hasReadAfterWriteInterference(const DenseSet< OpOperand * > &usesRead, const DenseSet< OpOperand * > &usesWrite, const DominanceInfo &domInfo, OneShotAnalysisState &state)
Given sets of uses and writes, return true if there is a RaW conflict under the assumption that all g...
static void getAliasingReads(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)
static void equivalenceAnalysis(SmallVector< Operation * > &ops, OneShotAnalysisState &state)
Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace)
Mark whether OpOperand will be bufferized inplace.
static SmallVector< Operation * > bottomUpFromTerminatorsHeuristic(Operation *op, const OneShotAnalysisState &state)
"Bottom-up from terminators" heuristic.
constexpr StringLiteral kInPlaceOperandsAttrName
Attribute marker to specify op operands that bufferize in-place.
static bool isaTensor(Type t)
static void annotateNonWritableTensor(Value value)
Annotate IR with details about the detected non-writability conflict.
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)
Return true if op dominance can be used to rule out a read-after-write conflicts based on the orderin...
static LogicalResult bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo)
Determine if operand can be bufferized in-place.
static bool matchesInsertDestination(const AnalysisState &state, Value value, SubsetInsertionOpInterface subsetOp)
Return "true" if value is originating from a subset that is equivalent to the subset that subsetOp in...
constexpr StringLiteral kOpResultAliasSetAttrName
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, Value start, Value other)
Return 'true' if a tensor that is equivalent to other can be found in the reverse use-def chain of st...
static bool happensBefore(Operation *a, Operation *b, const DominanceInfo &domInfo)
Return true if a happens before b, i.e., a or one of its ancestors properly dominates b and b is not ...
static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)
static bool wouldCreateWriteToNonWritableBuffer(OpOperand &operand, OneShotAnalysisState &state, bool checkConsistencyOnly=false)
Return true if bufferizing operand inplace would create a write to a non-writable buffer.
static void annotateOpsWithAliasSets(Operation *op, const OneShotAnalysisState &state)
static LogicalResult checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo, OneShotAnalysisState &state)
Perform various checks on the input IR to see if it contains IR constructs that are unsupported by On...
static void annotateOpsWithBufferizationMarkers(Operation *op, const OneShotAnalysisState &state)
Annotate the IR with the result of the analysis. For testing/debugging only.
static bool wouldCreateReadAfterWriteInterference(OpOperand &operand, const DominanceInfo &domInfo, OneShotAnalysisState &state, bool checkConsistencyOnly=false)
Return true if bufferizing operand inplace would create a conflict.
constexpr StringLiteral kBbArgAliasSetAttrName
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite, const SetVector< Value > &definitions, AnalysisState &state)
Return true if op dominance can be used to rule out a read-after-write conflicts based on the orderin...
static void getAliasingInplaceWrites(DenseSet< OpOperand * > &res, Value root, const OneShotAnalysisState &state)
static bool areNonConflictingSubsets(OpOperand *uRead, OpOperand *uConflictingWrite, const AnalysisState &state)
Return "true" if the given "read" and potentially conflicting "write" are not conflicting due to thei...
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value definition)
Annotate IR with details about the detected RaW conflict.
static bool isInplaceMemoryWrite(OpOperand &opOperand, const OneShotAnalysisState &state)
Return true if opOperand has been decided to bufferize in-place.
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:263
Base class for generic analysis states.
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:540
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * findAncestorOpInBlock(Operation &op)
Returns 'op' if 'op' lies in this block, or otherwise finds the ancestor operation of 'op' that lies ...
Definition: Block.cpp:76
bool isReachable(Block *other, SmallPtrSet< Block *, 16 > &&except={})
Return "true" if there is a path from this block to the given block (according to the successors rela...
Definition: Block.cpp:355
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:153
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class helps build Operations.
Definition: Builders.h:216
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
result_range getOpResults()
Definition: Operation.h:420
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
AnalysisState provides a variety of helper functions for dealing with tensor values.
AliasingValueList getAliasingValues(OpOperand &opOperand) const
Determine which Value will alias with opOperand if the op is bufferized in place.
bool bufferizesToMemoryWrite(OpOperand &opOperand) const
Return true if opOperand bufferizes to a memory write.
SetVector< Value > findDefinitions(Value value) const
Find the values that may define the contents of the given value at runtime.
virtual ~Extension()
Base virtual destructor.
State for analysis-enabled bufferization.
void bufferizeOutOfPlace(OpOperand &operand)
Mark the given OpOperand as out-of-place.
bool isWritable(Value value) const
Return true if the buffer of the given tensor value is writable.
const SetVector< Value > & findDefinitionsCached(Value value)
Find the definitions of the given tensor value or retrieve them from the cache.
bool isInPlace(OpOperand &opOperand) const override
Return true if the given OpOperand has been decided to bufferize inplace.
LogicalResult analyzeOp(Operation *op, const DominanceInfo &domInfo)
Analyze the given op and its nested ops.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
void unionEquivalenceClasses(Value v1, Value v2)
Union the equivalence classes of v1 and v2.
void gatherUndefinedTensorUses(Operation *op)
Find all tensor values in the given operation that have undefined contents and store them in undefine...
void resetCache() override
Reset cached data structures.
LogicalResult analyzeSingleOp(Operation *op, const DominanceInfo &domInfo)
Analyze a single op (without nested ops).
void applyOnEquivalenceClass(Value v, function_ref< void(Value)> fun) const
Apply fun to all the members of the equivalence class of v.
bool hasUndefinedContents(OpOperand *opOperand) const override
Return true if the given tensor has undefined contents.
void bufferizeInPlace(OpOperand &operand)
Mark the given OpOperand as in-place and merge the results' and operand's aliasing sets.
void applyOnAliases(Value v, function_ref< void(Value)> fun) const
Apply fun to all aliases of v.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
OneShotAnalysisState(Operation *op, const OneShotBufferizationOptions &options)
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
void unionAliasSets(Value v1, Value v2)
Union the alias sets of v1 and v2.
void createAliasInfoEntry(Value v)
Add a new entry for v in the aliasInfo and equivalentInfo.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
Definition: Bufferize.cpp:305
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
Region * getParallelRegion(Region *region, const BufferizationOptions &options)
If region is a parallel region, return region.
Region * getNextEnclosingRepetitiveRegion(Region *region, const BufferizationOptions &options)
Assuming that the given region is repetitive, find the next enclosing repetitive region.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This iterator enumerates elements in "reverse" order.
Definition: Iterators.h:29
Options for BufferizableOpInterface-based bufferization.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const
Try to cast the given op to BufferizableOpInterface if the op is allow listed.
bool isOpAllowed(Operation *op) const
Return true if the given op should be bufferized.
Bufferization statistics for debugging.
Definition: Bufferize.h:34
Options for analysis-enabled bufferization.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
Traversal parameters for findValueInReverseUseDefChain.