MLIR  20.0.0git
ConvertToDestinationStyle.cpp
Go to the documentation of this file.
1 //===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns to convert non-DPS ops to DPS ops. New
10 // tensor.empty ops are inserted as a destination. Such tensor.empty can be
11 // eliminated with "empty tensor elimination", allowing them to bufferize
12 // without an allocation (assuming there are no further conflicts).
13 //
14 //===----------------------------------------------------------------------===//
15 //
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
28 
29 using namespace mlir;
30 using namespace mlir::tensor;
31 
32 // Implements backtracking to traverse indices of the output buffer while
33 // iterating over op.elements().
34 static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
35  Value destination, ArrayRef<int64_t> shape,
36  ArrayRef<Value> constants,
37  OperandRange::iterator &elementIt,
38  SmallVectorImpl<Value> &indices) {
39  if (dim == static_cast<int>(shape.size()) - 1) {
40  for (int i = 0; i < shape.back(); ++i) {
41  indices.back() = constants[i];
42  destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
43  destination, indices);
44  ++elementIt;
45  }
46  return destination;
47  }
48  for (int i = 0; i < shape[dim]; ++i) {
49  indices[dim] = constants[i];
50  destination = createInserts(rewriter, loc, dim + 1, destination, shape,
51  constants, elementIt, indices);
52  }
53  return destination;
54 }
55 
56 /// Create a memcpy from the given source tensor to the given destination
57 /// memref. The copy op type can be specified in the `options`.
58 static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
59  Value memrefDest,
61  auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
62  assert(tensorType && "expected ranked tensor");
63  assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
64 
65  switch (options.memcpyOp) {
68  // Note: This is the preferred way of memcpy'ing because no layout map
69  // and/or memory space must be specified for the source.
70  auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
71  loc, tensorSource, memrefDest);
72  materializeOp.setWritable(true);
73  } break;
75  // TODO: Support custom memory space on source.
76  // We do not know the layout map of the source yet, so use a fully dynamic
77  // layout for best compatibility.
78  Value toMemref = b.create<bufferization::ToMemrefOp>(
80  tensorSource, /*readOnly=*/true);
81  b.create<memref::CopyOp>(loc, toMemref, memrefDest);
82  } break;
84  // TODO: Support custom memory space on source.
85  // We do not know the layout map of the source yet, so use a fully dynamic
86  // layout for best compatibility.
87  Value toMemref = b.create<bufferization::ToMemrefOp>(
89  tensorSource, /*readOnly=*/true);
90  b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
91  } break;
92  };
93 }
94 
96  Location loc, PadOp padOp,
97  Value dest) {
98  OpBuilder::InsertionGuard g(rewriter);
99  RankedTensorType resultType = padOp.getResultType();
100 
101  // Examine the yielded value to decide if a linalg.generic is needed or a
102  // linalg.fill is sufficient.
103  Value yieldedValue =
104  cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
105  Attribute constYieldedValue;
106  // Is the yielded value a bbArg defined outside of the PadOp?
107  bool outsideBbArg =
108  isa<BlockArgument>(yieldedValue) &&
109  cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
110  padOp.getOperation();
111  // Is the yielded value an OpResult defined outside of the PadOp?
112  bool outsideOpResult =
113  isa<OpResult>(yieldedValue) &&
114  yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
115  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
116  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
117  // Padding with a constant: Create linalg.fill.
118  Dialect *arithDialect =
119  rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
120  Value fillValue =
121  arithDialect
122  ->materializeConstant(rewriter, constYieldedValue,
123  yieldedValue.getType(), yieldedValue.getLoc())
124  ->getResult(0);
125  auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
126  ValueRange(dest));
127  return fillOp;
128  }
129 
130  if (invariantYieldedValue) {
131  // Padding with an invariant value.
132  auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
133  ValueRange(dest));
134  return fillOp;
135  }
136 
137  // Create linalg.generic.
138  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
139  utils::IteratorType::parallel);
140  SmallVector<AffineMap> indexingMaps(
141  1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
142  auto genericOp = rewriter.create<linalg::GenericOp>(
143  loc, resultType, /*inputs=*/ValueRange(),
144  /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
145  indexingMaps, iteratorTypes);
146  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
147  resultType.getElementType(), loc);
148  rewriter.setInsertionPointToStart(body);
149  SmallVector<Value> bbArgReplacements;
150  for (int64_t i = 0; i < resultType.getRank(); ++i)
151  bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
152  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
153 
154  // Update terminator.
155  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
156  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
157  return genericOp;
158 }
159 
161  Value value) {
162  auto tensorType = cast<RankedTensorType>(value.getType());
163  if (tensorType.hasStaticShape())
164  return {};
165 
166  // Try to reify dynamic sizes.
167  ReifiedRankedShapedTypeDims reifiedShape;
168  if (isa<OpResult>(value) &&
169  succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
170  SmallVector<Value> dynSizes;
171  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
172  if (tensorType.isDynamicDim(i))
173  dynSizes.push_back(
174  reifiedShape[cast<OpResult>(value).getResultNumber()][i]
175  .get<Value>());
176  }
177  return dynSizes;
178  }
179 
180  // Create tensor.dim ops.
181  SmallVector<Value> dynSizes;
182  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
183  if (tensorType.isDynamicDim(i))
184  dynSizes.push_back(
185  b.create<DimOp>(value.getLoc(), value,
186  b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
187  }
188  return dynSizes;
189 }
190 
191 static Value
194  Attribute memorySpace = {}) {
195  OpBuilder::InsertionGuard g(rewriter);
196  auto tensorType = cast<RankedTensorType>(value.getType());
197 
198  // Create buffer allocation.
199  auto memrefType =
201  tensorType, memorySpace));
202  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
203 
204  Value alloc;
205  if (options.allocOp ==
207  alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
208  if (options.emitDealloc) {
209  // Place deallocation at the end of the block.
210  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
211  rewriter.create<memref::DeallocOp>(loc, alloc);
212  }
213  } else if (options.allocOp ==
215  alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
216  // No dealloc is needed.
217  }
218 
219  return alloc;
220 }
221 
224  PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
225  // tensor.pad does not have a destination operand.
226  assert(!options.bufferizeDestinationOnly && "invalid options");
227 
228  OpBuilder::InsertionGuard g(rewriter);
229  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
230  Location loc = padOp.getLoc();
231 
232  // Create buffer allocation.
233  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
234  options, memorySpace);
235  rewriter.setInsertionPoint(padOp);
236 
237  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
238  // Create linalg.fill or linalg.generic. Not needed if there is no padding.
239  Operation *fillOp =
240  movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
241  rewriter.setInsertionPointAfter(fillOp);
242  }
243 
244  // Create memcpy.
246  getMixedSizes(rewriter, loc, padOp.getSource());
247  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
248  rewriter.getIndexAttr(1));
249  Value subview = rewriter.create<memref::SubViewOp>(
250  loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
251  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);
252 
253  // Create bufferization.to_tensor with "restrict" and "writable". The returned
254  // tensor is a new buffer allocation, so it does not alias with any buffer.
255  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
256  loc, alloc, /*restrict=*/true, /*writable=*/true);
257  rewriter.replaceOp(padOp, toTensorOp);
258  return alloc;
259 }
260 
263  vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
264  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
265  "expected single masked op");
266  OpBuilder::InsertionGuard g(rewriter);
267  bufferization::BufferizationOptions bufferizationOptions;
268  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
269  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
270 
271  // Bufferize maskable op. By default, place the buffer allocation right before
272  // the mask op.
274  rewriter, options, maskOp.getMaskableOp(), memorySpace,
275  /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
276 
277  if (options.bufferizeDestinationOnly)
278  return alloc;
279 
280  // Bufferize terminator.
281  rewriter.setInsertionPoint(yieldOp);
282  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
283  rewriter, bufferizationOptions)))
284  return nullptr;
285 
286  // Erase dead to_tensor ops inside of the mask op. This is necessary because
287  // there must be only one op (apart from the terminator) inside the mask op.
288  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
289  SmallVector<Operation *> toTensorOps;
290  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
291  if (toTensorOp->getUses().empty())
292  toTensorOps.push_back(toTensorOp.getOperation());
293  });
294  for (Operation *op : toTensorOps)
295  rewriter.eraseOp(op);
296 
297  // Bufferize mask op.
298  SmallVector<OpOperand *> resultUses;
299  for (Value result : maskOp.getResults())
300  if (isa<TensorType>(result.getType()))
301  for (OpOperand &use : result.getUses())
302  resultUses.push_back(&use);
303  rewriter.setInsertionPoint(maskOp);
304  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
305  .bufferize(rewriter, bufferizationOptions)))
306  return nullptr;
307 
308  // Set "restrict" attribute, indicating that no other tensor aliases with
309  // this tensor. That is because we just allocated a new buffer for the tensor.
310  for (OpOperand *resultUse : resultUses) {
311  auto toTensorOp =
312  resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
313  assert(toTensorOp && "expected to_tensor op");
314  rewriter.modifyOpInPlace(toTensorOp, [&]() {
315  toTensorOp.setRestrict(true);
316  toTensorOp.setWritable(true);
317  });
318  }
319 
320  return alloc;
321 }
322 
325  bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
326  Operation *insertionPoint) {
327  Location loc = allocTensorOp.getLoc();
328  OpBuilder::InsertionGuard g(rewriter);
329  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
330  bufferization::BufferizationOptions bufferizationOptions;
331 
332  // Create buffer allocation.
334  rewriter, loc, allocTensorOp.getResult(), options, memorySpace);
335 
336  // Create bufferization.to_tensor with "restrict" and "writable". The returned
337  // tensor is a new buffer allocation, so it does not alias with any buffer.
338  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
339  loc, alloc, /*restrict=*/true, /*writable=*/true);
340  rewriter.replaceOp(allocTensorOp, toTensorOp);
341  return alloc;
342 }
343 
344 /// Lower tensor.from_elements to a sequence of chained tensor.insert.
346  RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
347  Location loc = fromElementsOp.getLoc();
348  RankedTensorType tensorType =
349  cast<RankedTensorType>(fromElementsOp.getType());
350  auto shape = tensorType.getShape();
351 
352  // Create tensor.empty.
353  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
354 
355  // Case: tensor<elem_type>.
356  if (shape.empty()) {
357  Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
358  fromElementsOp, fromElementsOp.getElements().front(),
359  emptyOp.getResult(), ValueRange());
360  return res;
361  }
362 
363  // Create constants for the range of possible indices [0, max{shape_i}).
364  auto maxDim = *llvm::max_element(shape);
365  SmallVector<Value, 2> constants;
366  constants.reserve(maxDim);
367  for (int i = 0; i < maxDim; ++i)
368  constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
369 
370  // Traverse all elements and create tensor.insert ops.
371  auto elementIt = fromElementsOp.getElements().begin();
372  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
373  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
374  shape, constants, elementIt, indices);
375 
376  // Replace tensor.from_elements.
377  rewriter.replaceOp(fromElementsOp, result);
378  return result.getDefiningOp();
379 }
380 
381 /// Lower tensor.generate to linalg.generic.
382 FailureOr<Operation *>
384  tensor::GenerateOp generateOp) {
385  // Only ops with exactly one block are supported.
386  if (!generateOp.getBody().hasOneBlock())
387  return failure();
388 
389  Location loc = generateOp.getLoc();
390  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
391 
392  // Create tensor.empty.
393  auto emptyOp =
394  rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());
395 
396  // Create linalg.generic.
397  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
398  utils::IteratorType::parallel);
399  SmallVector<AffineMap> indexingMaps(
400  1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
401  auto genericOp = rewriter.create<linalg::GenericOp>(
402  loc, tensorType, /*inputs=*/ValueRange(),
403  /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
404  indexingMaps, iteratorTypes);
405  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
406  tensorType.getElementType(), loc);
407  rewriter.setInsertionPointToStart(body);
408  SmallVector<Value> bbArgReplacements;
409  for (int64_t i = 0; i < tensorType.getRank(); ++i)
410  bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
411  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
412 
413  // Update terminator.
414  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
415  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
416 
417  // Replace tensor.generate.
418  rewriter.replaceOp(generateOp, genericOp->getResult(0));
419  return genericOp.getOperation();
420 }
421 
422 /// Lower tensor.pad to linalg.fill or linalg.generic + tensor.insert_slice, or to linalg.copy for a nofold pad with all-zero padding.
423 FailureOr<Operation *>
425  tensor::PadOp padOp) {
426  // Only ops with exactly one block are supported.
427  if (!padOp.getBodyRegion().hasOneBlock())
428  return failure();
429 
430  // Create tensor.empty.
431  Location loc = padOp.getLoc();
432  RankedTensorType resultType = padOp.getResultType();
433  ReifiedRankedShapedTypeDims reifiedShape;
434  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
435  return rewriter.notifyMatchFailure(
436  padOp, "failed to reify tensor.pad op result shape");
437  SmallVector<Value> dynamicSizes;
438  for (int64_t i = 0; i < resultType.getRank(); ++i)
439  if (resultType.isDynamicDim(i))
440  dynamicSizes.push_back(reifiedShape[0][i].get<Value>());
441 
442  // If the `padOp` has a nofold attribute and all paddings are known to be 0,
443  // explicitly insert a `linalg.copy`.
444  if (padOp.getNofoldAttr() &&
445  llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
446  llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
447  using bufferization::AllocTensorOp;
448  Value allocated =
449  rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
450  auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
451  padOp, padOp.getSource(), allocated);
452  return copyOp.getOperation();
453  }
454 
455  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
456  // Create linalg.fill or linalg.generic.
457  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
458  rewriter.setInsertionPointAfter(fillOp);
459 
460  // Create tensor::InsertSliceOp.
461  SmallVector<OpFoldResult> sliceSizes =
462  getMixedSizes(rewriter, loc, padOp.getSource());
463  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
464  rewriter.getIndexAttr(1));
465  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
466  padOp, padOp.getSource(), fillOp->getResult(0),
467  /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
468  return insertSliceOp.getOperation();
469 }
470 
473  Operation *op, Attribute memorySpace, Operation *insertionPoint) {
474  using namespace bufferization;
475 
476  // Call specialized overload for certain ops.
477  if (auto padOp = dyn_cast<tensor::PadOp>(op))
478  return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
479  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
480  return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
481  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
482  return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);
483 
484  // Only bufferizable ops are supported.
485  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
486  if (!bufferizableOp)
487  return nullptr;
488  BufferizationOptions bufferizationOptions;
489  AnalysisState state(bufferizationOptions);
490 
491 #ifndef NDEBUG
492  if (!options.bufferizeDestinationOnly) {
493  // Ops with nested tensor ops are not supported yet. At the moment, this
494  // function just bufferizes the given op itself, but not its body.
495  op->walk([&](Operation *nestedOp) {
496  if (op == nestedOp)
497  return;
498  if (llvm::any_of(nestedOp->getOperands(),
499  [](Value v) { return isa<TensorType>(v.getType()); }))
500  llvm_unreachable("ops with nested tensor ops are not supported yet");
501  if (llvm::any_of(nestedOp->getResults(),
502  [](Value v) { return isa<TensorType>(v.getType()); }))
503  llvm_unreachable("ops with nested tensor ops are not supported yet");
504  });
505  }
506 #endif // NDEBUG
507 
508  // Gather tensor results.
509  SmallVector<OpResult> tensorResults;
510  for (OpResult result : op->getResults()) {
511  if (!isa<TensorType>(result.getType()))
512  continue;
513  // Unranked tensors are not supported
514  if (!isa<RankedTensorType>(result.getType()))
515  return nullptr;
516  // Ops that bufferize to an allocation are not supported.
517  if (bufferizableOp.bufferizesToAllocation(result))
518  return nullptr;
519  tensorResults.push_back(result);
520  }
521 
522  // Gather all operands that should bufferize to a new allocation. I.e.,
523  // bufferize out-of-place.
524  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
525  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
526  if (!llvm::is_contained(outOfPlaceOperands, operand))
527  outOfPlaceOperands.push_back(operand);
528  };
529  for (OpResult result : tensorResults) {
530  AliasingOpOperandList aliasingOperands =
531  state.getAliasingOpOperands(result);
532  for (const AliasingOpOperand &operand : aliasingOperands) {
533  addOutOfPlaceOperand(operand.opOperand);
534  for (OpOperand &resultUse : result.getUses())
535  resultUses.push_back(&resultUse);
536  }
537  }
538  for (OpOperand &operand : op->getOpOperands()) {
539  if (!state.bufferizesToMemoryWrite(operand))
540  continue;
541  if (!isa<RankedTensorType>(operand.get().getType()))
542  continue;
543  addOutOfPlaceOperand(&operand);
544  }
545  // TODO: Support multiple buffers.
546  if (outOfPlaceOperands.size() != 1)
547  return nullptr;
548 
549  // Allocate buffers.
550  OpBuilder::InsertionGuard g(rewriter);
551  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
552  SmallVector<Value> allocs;
553  for (OpOperand *operand : outOfPlaceOperands) {
555  rewriter, op->getLoc(), operand->get(), options, memorySpace);
556  allocs.push_back(alloc);
557  if (!state.findDefinitions(operand->get()).empty()) {
558  // Initialize buffer with a copy of the operand data. Not needed if the
559  // tensor is uninitialized.
560  createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
561  }
562  rewriter.modifyOpInPlace(op, [&]() {
563  auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
564  operand->set(toTensorOp);
565  if (options.bufferizeDestinationOnly) {
566  rewriter.modifyOpInPlace(toTensorOp, [&]() {
567  toTensorOp.setRestrict(true);
568  toTensorOp.setWritable(true);
569  });
570  }
571  });
572  }
573 
574  if (options.bufferizeDestinationOnly)
575  return allocs.front();
576 
577  // Bufferize the op.
578  rewriter.setInsertionPoint(op);
579  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
580  return nullptr;
581 
582  // Set "restrict" attribute, indicating that no other tensor aliases with
583  // this tensor. That is because we just allocated a new buffer for the tensor.
584  for (OpOperand *resultUse : resultUses) {
585  auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
586  assert(toTensorOp && "expected to_tensor op");
587  rewriter.modifyOpInPlace(toTensorOp, [&]() {
588  toTensorOp.setRestrict(true);
589  toTensorOp.setWritable(true);
590  });
591  }
592  return allocs.front();
593 }
594 
595 namespace {
596 
597 template <typename OpTy>
598 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
599  PatternRewriter &rewriter) {
600  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
601 }
602 
603 } // namespace
604 
606  RewritePatternSet &patterns) {
607  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
608  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
609  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
610 }
static Operation * movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace={})
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)
Create a memcpy from the given source tensor to the given destination memref.
static SmallVector< Value > reifyOrComputeDynamicSizes(OpBuilder &b, Value value)
static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef< int64_t > shape, ArrayRef< Value > constants, OperandRange::iterator &elementIt, SmallVectorImpl< Value > &indices)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:128
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:398
MLIRContext * getContext() const
Definition: Builders.h:55
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:83
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:415
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:445
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
AliasList< AliasingOpOperand > AliasingOpOperandList
A list of possible aliasing OpOperands.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to a sequence of chained tensor.insert ops.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)
Populate patterns that convert non-destination-style ops to destination style ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with an attribute whose value is 0.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
Options for BufferizableOpInterface-based bufferization.