MLIR  19.0.0git
Go to the documentation of this file.
1 //===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns to convert non-DPS ops to DPS ops. New
10 // tensor.empty ops are inserted as a destination. Such tensor.empty can be
11 // eliminated with "empty tensor elimination", allowing them to bufferize
12 // without an allocation (assuming there are no further conflicts).
13 //
14 //===----------------------------------------------------------------------===//
15 //
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/Debug.h"
29 using namespace mlir;
30 using namespace mlir::tensor;
32 // Implements backtracking to traverse indices of the output buffer while
33 // iterating over op.elements().
34 static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
35  Value destination, ArrayRef<int64_t> shape,
36  ArrayRef<Value> constants,
37  OperandRange::iterator &elementIt,
38  SmallVectorImpl<Value> &indices) {
39  if (dim == static_cast<int>(shape.size()) - 1) {
40  for (int i = 0; i < shape.back(); ++i) {
41  indices.back() = constants[i];
42  destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
43  destination, indices);
44  ++elementIt;
45  }
46  return destination;
47  }
48  for (int i = 0; i < shape[dim]; ++i) {
49  indices[dim] = constants[i];
50  destination = createInserts(rewriter, loc, dim + 1, destination, shape,
51  constants, elementIt, indices);
52  }
53  return destination;
54 }
56 /// Create a memcpy from the given source tensor to the given destination
57 /// memref. The copy op type can be specified in the `options`.
///
/// `tensorSource` must be a ranked tensor and `memrefDest` a memref (both are
/// asserted). Depending on `options.memcpyOp`, the copy is emitted either as a
/// writable bufferization.materialize_in_destination, or as a read-only
/// bufferization.to_memref of the source followed by memref.copy or
/// linalg.copy into `memrefDest`. Ops are created at the builder's current
/// insertion point.
58 static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
59  Value memrefDest,
61  auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
62  assert(tensorType && "expected ranked tensor");
63  assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
65  switch (options.memcpyOp) {
  // Copy via bufferization.materialize_in_destination.
68  // Note: This is the preferred way of memcpy'ing because no layout map
69  // and/or memory space must be specified for the source.
70  auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
71  loc, tensorSource, memrefDest);
72  materializeOp.setWritable(true);
73  } break;
  // Copy via a read-only bufferization.to_memref of the source + memref.copy.
75  // TODO: Support custom memory space on source.
76  // We do not know the layout map of the source yet, so use a fully dynamic
77  // layout for best compatibility.
78  Value toMemref = b.create<bufferization::ToMemrefOp>(
80  tensorSource, /*readOnly=*/true);
81  b.create<memref::CopyOp>(loc, toMemref, memrefDest);
82  } break;
  // Copy via a read-only bufferization.to_memref of the source + linalg.copy.
84  // TODO: Support custom memory space on source.
85  // We do not know the layout map of the source yet, so use a fully dynamic
86  // layout for best compatibility.
87  Value toMemref = b.create<bufferization::ToMemrefOp>(
89  tensorSource, /*readOnly=*/true);
90  b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
91  } break;
92  };
93 }
96  Location loc, PadOp padOp,
97  Value dest) {
98  OpBuilder::InsertionGuard g(rewriter);
99  RankedTensorType resultType = padOp.getResultType();
101  // Examine the yielded value to decide if a linalg.generic is neede or a
102  // linalg.fill is sufficient.
103  Value yieldedValue =
104  cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
105  Attribute constYieldedValue;
106  // Is the yielded value a bbArg defined outside of the PadOp?
107  bool outsideBbArg =
108  isa<BlockArgument>(yieldedValue) &&
109  cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
110  padOp.getOperation();
111  // Is the yielded value an OpResult defined outside of the PadOp?
112  bool outsideOpResult =
113  isa<OpResult>(yieldedValue) &&
114  yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
115  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
116  if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
117  // Padding with a constant: Create linalg.fill.
118  Dialect *arithDialect =
119  rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
120  Value fillValue =
121  arithDialect
122  ->materializeConstant(rewriter, constYieldedValue,
123  yieldedValue.getType(), yieldedValue.getLoc())
124  ->getResult(0);
125  auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue),
126  ValueRange(dest));
127  return fillOp;
128  }
130  if (invariantYieldedValue) {
131  // Padding with an invariant value.
132  auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue),
133  ValueRange(dest));
134  return fillOp;
135  }
137  // Create linalg.generic.
138  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
139  utils::IteratorType::parallel);
140  SmallVector<AffineMap> indexingMaps(
141  1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
142  auto genericOp = rewriter.create<linalg::GenericOp>(
143  loc, resultType, /*inputs=*/ValueRange(),
144  /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
145  indexingMaps, iteratorTypes);
146  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
147  resultType.getElementType(), loc);
148  rewriter.setInsertionPointToStart(body);
149  SmallVector<Value> bbArgReplacements;
150  for (int64_t i = 0; i < resultType.getRank(); ++i)
151  bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
152  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
154  // Update terminator.
155  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
156  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
157  return genericOp;
158 }
161  Value value) {
  // Returns one index Value per dynamic dimension of the ranked tensor
  // `value`, in dimension order (empty for static shapes). Sizes are taken
  // from reifyResultShapes on the defining op when `value` is an OpResult and
  // reification succeeds; otherwise tensor.dim ops are created.
162  auto tensorType = cast<RankedTensorType>(value.getType());
163  if (tensorType.hasStaticShape())
164  return {};
166  // Try to reify dynamic sizes.
167  ReifiedRankedShapedTypeDims reifiedShape;
168  if (isa<OpResult>(value) &&
169  succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
170  SmallVector<Value> dynSizes;
171  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
  // NOTE(review): assumes reifyResultShapes reports dynamic dimensions as
  // Values (not Attributes); `.get<Value>()` would assert otherwise.
172  if (tensorType.isDynamicDim(i))
173  dynSizes.push_back(
174  reifiedShape[cast<OpResult>(value).getResultNumber()][i]
175  .get<Value>());
176  }
177  return dynSizes;
178  }
180  // Create tensor.dim ops.
181  SmallVector<Value> dynSizes;
182  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
183  if (tensorType.isDynamicDim(i))
184  dynSizes.push_back(
185  b.create<DimOp>(value.getLoc(), value,
186  b.create<arith::ConstantIndexOp>(value.getLoc(), i)));
187  }
188  return dynSizes;
189 }
191 static Value
194  Attribute memorySpace = {}) {
  // Creates a new buffer with the shape of the ranked tensor `value` (dynamic
  // sizes via reifyOrComputeDynamicSizes) in `memorySpace`, using memref.alloc
  // or memref.alloca depending on `options.allocOp`, and returns it. The
  // builder's insertion point is restored on return (InsertionGuard).
195  OpBuilder::InsertionGuard g(rewriter);
196  auto tensorType = cast<RankedTensorType>(value.getType());
198  // Create buffer allocation.
199  auto memrefType =
201  tensorType, memorySpace));
202  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
204  Value alloc;
205  if (options.allocOp ==
207  alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
208  if (options.emitDealloc) {
209  // Place deallocation at the end of the block.
  // (i.e. right before the terminator of the current insertion block.
  // NOTE(review): assumes that block has a terminator.)
210  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
211  rewriter.create<memref::DeallocOp>(loc, alloc);
212  }
213  } else if (options.allocOp ==
215  alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
216  // No dealloc is needed.
217  }
219  return alloc;
220 }
224  PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
  // Bufferizes `padOp` into a new buffer allocation, placed before
  // `insertionPoint` if given, otherwise before `padOp`:
  //  1. the padding value is written into the buffer (only if some low or high
  //     padding is non-zero),
  //  2. the source is copied into the subview at offsets `getMixedLowPad()`,
  //  3. `padOp` is replaced by a restrict+writable bufferization.to_tensor of
  //     the buffer.
  // Returns the allocation.
225  // tensor.pad does not have a destination operand.
226  assert(!options.bufferizeDestinationOnly && "invalid options");
228  OpBuilder::InsertionGuard g(rewriter);
229  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
230  Location loc = padOp.getLoc();
232  // Create buffer allocation.
233  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
234  options, memorySpace);
235  rewriter.setInsertionPoint(padOp);
237  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
238  // Create linalg.fill or linalg.generic. Not needed if there is no padding.
239  Operation *fillOp =
240  movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
241  rewriter.setInsertionPointAfter(fillOp);
242  }
244  // Create memcpy.
246  getMixedSizes(rewriter, loc, padOp.getSource());
247  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
248  rewriter.getIndexAttr(1));
249  Value subview = rewriter.create<memref::SubViewOp>(
250  loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
251  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);
253  // Create bufferization.to_tensor with "restrict" and "writable". The returned
254  // tensor is a new buffer allocation, so it does not alias with any buffer.
255  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
256  loc, alloc, /*restrict=*/true, /*writable=*/true);
257  rewriter.replaceOp(padOp, toTensorOp);
258  return alloc;
259 }
263  vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
  // Bufferizes the single op masked by `maskOp` into a new buffer allocation.
  // Unless `options.bufferizeDestinationOnly` is set, it then bufferizes the
  // mask's terminator and the mask op itself, and marks the to_tensor ops
  // that now feed the former tensor-result uses as restrict+writable.
  // Returns the allocation, or nullptr if bufferizing the terminator or the
  // mask op fails.
264  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
265  "expected single masked op");
266  OpBuilder::InsertionGuard g(rewriter);
267  bufferization::BufferizationOptions bufferizationOptions;
268  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
269  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
271  // Bufferize maskable op. By default, place the buffer allocation right before
272  // the mask op.
274  rewriter, options, maskOp.getMaskableOp(), memorySpace,
275  /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
277  if (options.bufferizeDestinationOnly)
278  return alloc;
280  // Bufferize terminator.
281  rewriter.setInsertionPoint(yieldOp);
282  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
283  rewriter, bufferizationOptions)))
284  return nullptr;
286  // Erase dead to_tensor ops inside of the mask op. This is necessary because
287  // there can only be one op (apart from the terminator) inside the mask op.
288  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
289  SmallVector<Operation *> toTensorOps;
290  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
291  if (toTensorOp->getUses().empty())
292  toTensorOps.push_back(toTensorOp.getOperation());
293  });
294  for (Operation *op : toTensorOps)
295  rewriter.eraseOp(op);
297  // Bufferize mask op.
  // Record the uses of tensor results first: after bufferization they are
  // expected to be fed by to_tensor ops (see the assert below).
298  SmallVector<OpOperand *> resultUses;
299  for (Value result : maskOp.getResults())
300  if (isa<TensorType>(result.getType()))
301  for (OpOperand &use : result.getUses())
302  resultUses.push_back(&use);
303  rewriter.setInsertionPoint(maskOp);
304  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
305  .bufferize(rewriter, bufferizationOptions)))
306  return nullptr;
308  // Set "restrict" attribute, indicating that no other tensor aliases with
309  // this tensor. That is because we just allocated a new buffer for the tensor.
310  for (OpOperand *resultUse : resultUses) {
311  auto toTensorOp =
312  resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
313  assert(toTensorOp && "expected to_tensor op");
314  rewriter.modifyOpInPlace(toTensorOp, [&]() {
315  toTensorOp.setRestrict(true);
316  toTensorOp.setWritable(true);
317  });
318  }
320  return alloc;
321 }
325  bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
326  Operation *insertionPoint) {
  // Replaces `allocTensorOp` with a restrict+writable bufferization.to_tensor
  // of a newly created buffer (placed before `insertionPoint` if given,
  // otherwise before `allocTensorOp`) and returns that buffer.
  // NOTE(review): `bufferizationOptions` below appears unused, and no data is
  // copied into the new buffer; confirm alloc_tensor ops with a `copy`
  // operand are not expected here.
327  Location loc = allocTensorOp.getLoc();
328  OpBuilder::InsertionGuard g(rewriter);
329  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
330  bufferization::BufferizationOptions bufferizationOptions;
332  // Create buffer allocation.
334  rewriter, loc, allocTensorOp.getResult(), options, memorySpace);
336  // Create bufferization.to_tensor with "restrict" and "writable". The returned
337  // tensor is a new buffer allocation, so it does not alias with any buffer.
338  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
339  loc, alloc, /*restrict=*/true, /*writable=*/true);
340  rewriter.replaceOp(allocTensorOp, toTensorOp);
341  return alloc;
342 }
344 /// Lower tensor.from_elements to a sequence of chained tensor.insert.
346  RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
347  Location loc = fromElementsOp.getLoc();
348  RankedTensorType tensorType =
349  cast<RankedTensorType>(fromElementsOp.getType());
350  auto shape = tensorType.getShape();
352  // Create tensor.empty.
353  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());
355  // Case: tensor<elem_type>.
356  if (shape.empty()) {
357  Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
358  fromElementsOp, fromElementsOp.getElements().front(),
359  emptyOp.getResult(), ValueRange());
360  return res;
361  }
363  // Create constants for the range of possible indices [0, max{shape_i}).
364  auto maxDim = *llvm::max_element(shape);
365  SmallVector<Value, 2> constants;
366  constants.reserve(maxDim);
367  for (int i = 0; i < maxDim; ++i)
368  constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
370  // Traverse all elements and create tensor.insert ops.
371  auto elementIt = fromElementsOp.getElements().begin();
372  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
373  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
374  shape, constants, elementIt, indices);
376  // Replace tensor.from_elements.
377  rewriter.replaceOp(fromElementsOp, result);
378  return result.getDefiningOp();
379 }
381 /// Lower tensor.generate to linalg.generic.
384  tensor::GenerateOp generateOp) {
  // The generic is all-parallel with a single identity indexing map; its init
  // is a tensor.empty with the op's dynamic extents. The generate body is
  // moved into the generic and its index block arguments are replaced with
  // linalg.index ops. Fails unless the body region has exactly one block.
  // NOTE(review): no InsertionGuard is used here, so the rewriter's insertion
  // point is not restored on return.
385  // Only ops with exactly one block are supported.
386  if (!generateOp.getBody().hasOneBlock())
387  return failure();
389  Location loc = generateOp.getLoc();
390  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
392  // Create tensor.empty.
393  auto emptyOp =
394  rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());
396  // Create linalg.generic.
397  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
398  utils::IteratorType::parallel);
399  SmallVector<AffineMap> indexingMaps(
400  1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
401  auto genericOp = rewriter.create<linalg::GenericOp>(
402  loc, tensorType, /*inputs=*/ValueRange(),
403  /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
404  indexingMaps, iteratorTypes);
405  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
406  tensorType.getElementType(), loc);
407  rewriter.setInsertionPointToStart(body);
  // One linalg.index per result dimension replaces the generate body's
  // index block arguments.
408  SmallVector<Value> bbArgReplacements;
409  for (int64_t i = 0; i < tensorType.getRank(); ++i)
410  bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
411  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
413  // Update terminator.
414  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
415  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
417  // Replace tensor.generate.
418  rewriter.replaceOp(generateOp, genericOp->getResult(0));
419  return genericOp.getOperation();
420 }
422 /// Lower tensor.pad to linalg.generic + tensor.insert_slice.
425  tensor::PadOp padOp) {
  // Two lowerings are used:
  //  - With the nofold attribute and all-zero low/high padding, the op becomes
  //    a linalg.copy into a bufferization.alloc_tensor.
  //  - Otherwise, a tensor.empty is filled with the padding value (linalg.fill
  //    or linalg.generic), and the source is inserted at offsets
  //    `getMixedLowPad()` via tensor.insert_slice.
  // Fails if the body region does not have exactly one block, or if the result
  // shape cannot be reified.
426  // Only ops with exactly one block are supported.
427  if (!padOp.getBodyRegion().hasOneBlock())
428  return failure();
430  // Create tensor.empty.
431  Location loc = padOp.getLoc();
432  RankedTensorType resultType = padOp.getResultType();
433  ReifiedRankedShapedTypeDims reifiedShape;
434  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
435  return rewriter.notifyMatchFailure(
436  padOp, "failed to reify tensor.pad op result shape");
437  SmallVector<Value> dynamicSizes;
438  for (int64_t i = 0; i < resultType.getRank(); ++i)
439  if (resultType.isDynamicDim(i))
440  dynamicSizes.push_back(reifiedShape[0][i].get<Value>());
442  // If the `padOp` has a nofold attribute and all paddings are known to be 0,
443  // explicitly insert a `linalg.copy`.
444  if (padOp.getNofoldAttr() &&
445  llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
446  llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
447  using bufferization::AllocTensorOp;
448  Value allocated =
449  rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
450  auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
451  padOp, padOp.getSource(), allocated);
452  return copyOp.getOperation();
453  }
455  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
456  // Create linalg.fill or linalg.generic.
457  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
458  rewriter.setInsertionPointAfter(fillOp);
460  // Create tensor::InsertSliceOp.
461  SmallVector<OpFoldResult> sliceSizes =
462  getMixedSizes(rewriter, loc, padOp.getSource());
463  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
464  rewriter.getIndexAttr(1));
465  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
466  padOp, padOp.getSource(), fillOp->getResult(0),
467  /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
468  return insertSliceOp.getOperation();
469 }
473  Operation *op, Attribute memorySpace, Operation *insertionPoint) {
474  using namespace bufferization;
476  // Call specialized overload for certain ops.
477  if (auto padOp = dyn_cast<tensor::PadOp>(op))
478  return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
479  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
480  return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
481  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
482  return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);
484  // Only bufferizable ops are supported.
485  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
486  if (!bufferizableOp)
487  return nullptr;
488  BufferizationOptions bufferizationOptions;
489  AnalysisState state(bufferizationOptions);
491 #ifndef NDEBUG
492  if (!options.bufferizeDestinationOnly) {
493  // Ops with nested tensor ops are not supported yet. At the moment, this
494  // function just bufferizes the given op itself, but not its body.
495  op->walk([&](Operation *nestedOp) {
496  if (op == nestedOp)
497  return;
498  if (llvm::any_of(nestedOp->getOperands(),
499  [](Value v) { return isa<TensorType>(v.getType()); }))
500  llvm_unreachable("ops with nested tensor ops are not supported yet");
501  if (llvm::any_of(nestedOp->getResults(),
502  [](Value v) { return isa<TensorType>(v.getType()); }))
503  llvm_unreachable("ops with nested tensor ops are not supported yet");
504  });
505  }
506 #endif // NDEBUG
508  // Gather tensor results.
509  SmallVector<OpResult> tensorResults;
510  for (OpResult result : op->getResults()) {
511  if (!isa<TensorType>(result.getType()))
512  continue;
513  // Unranked tensors are not supported
514  if (!isa<RankedTensorType>(result.getType()))
515  return nullptr;
516  // Ops that bufferize to an allocation are not supported.
517  if (bufferizableOp.bufferizesToAllocation(result))
518  return nullptr;
519  tensorResults.push_back(result);
520  }
522  // Gather all operands that should bufferize to a new allocation. I.e.,
523  // bufferize out-of-place.
524  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
525  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
526  if (!llvm::is_contained(outOfPlaceOperands, operand))
527  outOfPlaceOperands.push_back(operand);
528  };
529  for (OpResult result : tensorResults) {
530  AliasingOpOperandList aliasingOperands =
531  state.getAliasingOpOperands(result);
532  for (const AliasingOpOperand &operand : aliasingOperands) {
533  addOutOfPlaceOperand(operand.opOperand);
534  for (OpOperand &resultUse : result.getUses())
535  resultUses.push_back(&resultUse);
536  }
537  }
538  for (OpOperand &operand : op->getOpOperands()) {
539  if (!state.bufferizesToMemoryWrite(operand))
540  continue;
541  if (!isa<RankedTensorType>(operand.get().getType()))
542  continue;
543  addOutOfPlaceOperand(&operand);
544  }
545  // TODO: Support multiple buffers.
546  if (outOfPlaceOperands.size() != 1)
547  return nullptr;
549  // Allocate buffers.
550  OpBuilder::InsertionGuard g(rewriter);
551  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
552  SmallVector<Value> allocs;
553  for (OpOperand *operand : outOfPlaceOperands) {
555  rewriter, op->getLoc(), operand->get(), options, memorySpace);
556  allocs.push_back(alloc);
557  if (!state.findDefinitions(operand->get()).empty()) {
558  // Initialize buffer with a copy of the operand data. Not needed if the
559  // tensor is uninitialized.
560  createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
561  }
562  rewriter.modifyOpInPlace(op, [&]() {
563  auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
564  operand->set(toTensorOp);
565  if (options.bufferizeDestinationOnly) {
566  rewriter.modifyOpInPlace(toTensorOp, [&]() {
567  toTensorOp.setRestrict(true);
568  toTensorOp.setWritable(true);
569  });
570  }
571  });
572  }
574  if (options.bufferizeDestinationOnly)
575  return allocs.front();
577  // Bufferize the op.
578  rewriter.setInsertionPoint(op);
579  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
580  return nullptr;
582  // Set "restrict" attribute, indicating that no other tensor aliases with
583  // this tensor. That is because we just allocated a new buffer for the tensor.
584  for (OpOperand *resultUse : resultUses) {
585  auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
586  assert(toTensorOp && "expected to_tensor op");
587  rewriter.modifyOpInPlace(toTensorOp, [&]() {
588  toTensorOp.setRestrict(true);
589  toTensorOp.setWritable(true);
590  });
591  }
592  return allocs.front();
593 }
595 namespace {
597 template <typename OpTy>
598 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
599  PatternRewriter &rewriter) {
600  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
601 }
603 } // namespace
606  RewritePatternSet &patterns) {
  // Register the tensor.from_elements, tensor.generate and tensor.pad
  // destination-passing-style rewrites as function-based patterns.
607  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
608  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
609  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
610 }
static Operation * movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace={})
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)
Create a memcpy from the given source tensor to the given destination memref.
static SmallVector< Value > reifyOrComputeDynamicSizes(OpBuilder &b, Value value)
static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef< int64_t > shape, ArrayRef< Value > constants, OperandRange::iterator &elementIt, SmallVectorImpl< Value > &indices)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
MLIRContext * getContext() const
Definition: Builders.h:55
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
AliasList< AliasingOpOperand > AliasingOpOperandList
A list of possible aliasing OpOperands.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to a sequence of chained tensor.insert ops.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)
Populate patterns that convert non-destination-style ops to destination style ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:61
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.