MLIR 22.0.0git
ConvertToDestinationStyle.cpp
Go to the documentation of this file.
//===- ConvertToDestinationStyle.cpp - Convert non-DPS to DPS ops ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns to convert non-DPS ops to DPS ops. New
// tensor.empty ops are inserted as a destination. Such tensor.empty can be
// eliminated with "empty tensor elimination", allowing them to bufferize
// without an allocation (assuming there are no further conflicts).
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
26
27using namespace mlir;
28using namespace mlir::tensor;
29
30// Implements backtracking to traverse indices of the output buffer while
31// iterating over op.elements().
32static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
33 Value destination, ArrayRef<int64_t> shape,
34 ArrayRef<Value> constants,
35 OperandRange::iterator &elementIt,
37 if (dim == static_cast<int>(shape.size()) - 1) {
38 for (int i = 0; i < shape.back(); ++i) {
39 indices.back() = constants[i];
40 destination = tensor::InsertOp::create(rewriter, loc, *elementIt,
41 destination, indices);
42 ++elementIt;
43 }
44 return destination;
45 }
46 for (int i = 0; i < shape[dim]; ++i) {
47 indices[dim] = constants[i];
48 destination = createInserts(rewriter, loc, dim + 1, destination, shape,
49 constants, elementIt, indices);
50 }
51 return destination;
52}
53
54/// Create a memcpy from the given source tensor to the given destination
55/// memref. The copy op type can be specified in the `options`.
56static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
57 Value memrefDest,
59 auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
60 assert(tensorType && "expected ranked tensor");
61 assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
62
63 switch (options.memcpyOp) {
64 case linalg::BufferizeToAllocationOptions::MemcpyOp::
65 MaterializeInDestination: {
66 // Note: This is the preferred way of memcpy'ing because no layout map
67 // and/or memory space must be specified for the source.
68 auto materializeOp = bufferization::MaterializeInDestinationOp::create(
69 b, loc, tensorSource, memrefDest);
70 materializeOp.setWritable(true);
71 } break;
73 // TODO: Support custom memory space on source.
74 // We do not know the layout map of the source yet, so use a fully dynamic
75 // layout for best compatibility.
76 Value toBuffer = bufferization::ToBufferOp::create(
77 b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
78 tensorSource, /*read_only=*/true);
79 memref::CopyOp::create(b, loc, toBuffer, memrefDest);
80 } break;
82 // TODO: Support custom memory space on source.
83 // We do not know the layout map of the source yet, so use a fully dynamic
84 // layout for best compatibility.
85 Value toBuffer = bufferization::ToBufferOp::create(
86 b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
87 tensorSource, /*read_only=*/true);
88 linalg::CopyOp::create(b, loc, toBuffer, memrefDest);
89 } break;
90 };
91}
92
94 Location loc, PadOp padOp,
95 Value dest) {
96 OpBuilder::InsertionGuard g(rewriter);
97 RankedTensorType resultType = padOp.getResultType();
98
99 // Examine the yielded value to decide if a linalg.generic is neede or a
100 // linalg.fill is sufficient.
101 Value yieldedValue =
102 cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
103 Attribute constYieldedValue;
104 // Is the yielded value a bbArg defined outside of the PadOp?
105 bool outsideBbArg =
106 isa<BlockArgument>(yieldedValue) &&
107 cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
108 padOp.getOperation();
109 // Is the yielded value an OpResult defined outside of the PadOp?
110 bool outsideOpResult =
111 isa<OpResult>(yieldedValue) &&
112 yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
113 bool invariantYieldedValue = outsideBbArg || outsideOpResult;
114 if (matchPattern(yieldedValue, m_Constant(&constYieldedValue))) {
115 // Padding with a constant: Create linalg.fill.
116 Dialect *arithDialect =
117 rewriter.getContext()->getLoadedDialect<arith::ArithDialect>();
118 Value fillValue =
119 arithDialect
120 ->materializeConstant(rewriter, constYieldedValue,
121 yieldedValue.getType(), yieldedValue.getLoc())
122 ->getResult(0);
123 auto fillOp = linalg::FillOp::create(rewriter, loc, ValueRange(fillValue),
124 ValueRange(dest));
125 return fillOp;
126 }
127
128 if (invariantYieldedValue) {
129 // Padding with an invariant value.
130 auto fillOp = linalg::FillOp::create(
131 rewriter, loc, ValueRange(yieldedValue), ValueRange(dest));
132 return fillOp;
133 }
134
135 // Create linalg.generic.
136 SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
137 utils::IteratorType::parallel);
138 SmallVector<AffineMap> indexingMaps(
139 1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
140 auto genericOp = linalg::GenericOp::create(
141 rewriter, loc, resultType, /*inputs=*/ValueRange(),
142 /*outputs=*/ValueRange{dest}, /*indexingMaps=*/
143 indexingMaps, iteratorTypes);
144 Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
145 resultType.getElementType(), loc);
146 rewriter.setInsertionPointToStart(body);
147 SmallVector<Value> bbArgReplacements;
148 for (int64_t i = 0; i < resultType.getRank(); ++i)
149 bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
150 rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);
151
152 // Update terminator.
153 auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
154 rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
155 return genericOp;
156}
157
159 Value value) {
160 auto tensorType = cast<RankedTensorType>(value.getType());
161 if (tensorType.hasStaticShape())
162 return {};
163
164 // Try to reify dynamic sizes.
165 ReifiedRankedShapedTypeDims reifiedShape;
166 if (isa<OpResult>(value) &&
167 succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
168 SmallVector<Value> dynSizes;
169 for (int64_t i = 0; i < tensorType.getRank(); ++i) {
170 if (tensorType.isDynamicDim(i))
171 dynSizes.push_back(cast<Value>(
172 reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
173 }
174 return dynSizes;
175 }
176
177 // Create tensor.dim ops.
178 SmallVector<Value> dynSizes;
179 for (int64_t i = 0; i < tensorType.getRank(); ++i) {
180 if (tensorType.isDynamicDim(i))
181 dynSizes.push_back(
182 DimOp::create(b, value.getLoc(), value,
184 }
185 return dynSizes;
186}
187
188static Value
191 Attribute memorySpace = {}) {
192 OpBuilder::InsertionGuard g(rewriter);
193 auto tensorType = cast<RankedTensorType>(value.getType());
194
195 // Create buffer allocation.
196 auto memrefType =
197 cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
198 tensorType, memorySpace));
199 SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);
200
201 Value alloc;
202 if (options.allocOp ==
204 alloc = memref::AllocOp::create(rewriter, loc, memrefType, dynamicSizes);
205 if (options.emitDealloc) {
206 // Place deallocation at the end of the block.
207 rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
208 memref::DeallocOp::create(rewriter, loc, alloc);
209 }
210 } else if (options.allocOp ==
212 alloc = memref::AllocaOp::create(rewriter, loc, memrefType, dynamicSizes);
213 // No dealloc is needed.
214 }
215
216 return alloc;
217}
218
221 PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
222 // tensor.pad does not have a destination operand.
223 assert(!options.bufferizeDestinationOnly && "invalid options");
224
225 OpBuilder::InsertionGuard g(rewriter);
226 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
227 Location loc = padOp.getLoc();
228
229 // Create buffer allocation.
230 Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
231 options, memorySpace);
232 rewriter.setInsertionPoint(padOp);
233
234 if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
235 // Create linalg.fill or linalg.generic. Not needed if there is no padding.
236 Operation *fillOp =
237 movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
238 rewriter.setInsertionPointAfter(fillOp);
239 }
240
241 // Create memcpy.
243 getMixedSizes(rewriter, loc, padOp.getSource());
244 SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
245 rewriter.getIndexAttr(1));
246 Value subview = memref::SubViewOp::create(
247 rewriter, loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
248 createMemcpy(rewriter, loc, padOp.getSource(), subview, options);
249
250 // Create bufferization.to_tensor with "restrict" and "writable". The returned
251 // tensor is a new buffer allocation, so it does not alias with any buffer.
252 Value toTensorOp = bufferization::ToTensorOp::create(
253 rewriter, loc, padOp.getResult().getType(), alloc, /*restrict=*/true,
254 /*writable=*/true);
255 rewriter.replaceOp(padOp, toTensorOp);
256 return alloc;
257}
258
261 vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
262 assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
263 "expected single masked op");
264 OpBuilder::InsertionGuard g(rewriter);
265
266 // Should the bufferization options and state be function arguments?
267 bufferization::BufferizationOptions bufferizationOptions;
268 bufferization::BufferizationState bufferizationState;
269
270 Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
271 assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
272
273 // Bufferize maskable op. By default, place the buffer allocation right before
274 // the mask op.
276 rewriter, options, maskOp.getMaskableOp(), memorySpace,
277 /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
278
279 if (options.bufferizeDestinationOnly)
280 return alloc;
281
282 // Bufferize terminator.
283 rewriter.setInsertionPoint(yieldOp);
284 if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
285 rewriter, bufferizationOptions, bufferizationState)))
286 return nullptr;
287
288 // Erase dead to_tensor ops inside of the mask op. This is necessary because
289 // there only be one op (apart from the terminator) inside the mask op.
290 // TODO: Remove dead to_tensor ops more aggressively during bufferization.
291 SmallVector<Operation *> toTensorOps;
292 maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
293 if (toTensorOp->getUses().empty())
294 toTensorOps.push_back(toTensorOp.getOperation());
295 });
296 for (Operation *op : toTensorOps)
297 rewriter.eraseOp(op);
298
299 // Bufferize mask op.
300 SmallVector<OpOperand *> resultUses;
301 for (Value result : maskOp.getResults())
302 if (isa<TensorType>(result.getType()))
303 for (OpOperand &use : result.getUses())
304 resultUses.push_back(&use);
305 rewriter.setInsertionPoint(maskOp);
306 if (failed(
307 cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
308 .bufferize(rewriter, bufferizationOptions, bufferizationState)))
309 return nullptr;
310
311 // Set "restrict" attribute, indicating that no other tensor aliases with
312 // this tensor. That is because we just allocated a new buffer for the tensor.
313 for (OpOperand *resultUse : resultUses) {
314 auto toTensorOp =
315 resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
316 assert(toTensorOp && "expected to_tensor op");
317 rewriter.modifyOpInPlace(toTensorOp, [&]() {
318 toTensorOp.setRestrict(true);
319 toTensorOp.setWritable(true);
320 });
321 }
322
323 return alloc;
324}
325
328 bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
329 Operation *insertionPoint) {
330 Location loc = allocTensorOp.getLoc();
331 OpBuilder::InsertionGuard g(rewriter);
332 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
333 bufferization::BufferizationOptions bufferizationOptions;
334
335 // Create buffer allocation.
337 rewriter, loc, allocTensorOp.getResult(), options, memorySpace);
338
339 // Create bufferization.to_tensor with "restrict" and "writable". The returned
340 // tensor is a new buffer allocation, so it does not alias with any buffer.
341 Value toTensorOp = bufferization::ToTensorOp::create(
342 rewriter, loc, allocTensorOp.getResult().getType(), alloc,
343 /*restrict=*/true,
344 /*writable=*/true);
345 rewriter.replaceOp(allocTensorOp, toTensorOp);
346 return alloc;
347}
348
349/// Lower tensor.from_elements to a sequence of chained tensor.insert.
351 RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
352 Location loc = fromElementsOp.getLoc();
353 RankedTensorType tensorType =
354 cast<RankedTensorType>(fromElementsOp.getType());
355 auto shape = tensorType.getShape();
356
357 // Create tensor.empty.
358 auto emptyOp = EmptyOp::create(rewriter, loc, tensorType, ValueRange());
359
360 // Case: tensor<elem_type>.
361 if (shape.empty()) {
362 Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
363 fromElementsOp, fromElementsOp.getElements().front(),
364 emptyOp.getResult(), ValueRange());
365 return res;
366 }
367
368 // Create constants for the range of possible indices [0, max{shape_i}).
369 auto maxDim = *llvm::max_element(shape);
370 SmallVector<Value, 2> constants;
371 constants.reserve(maxDim);
372 for (int i = 0; i < maxDim; ++i)
373 constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i));
374
375 // Traverse all elements and create tensor.insert ops.
376 auto elementIt = fromElementsOp.getElements().begin();
377 SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
378 Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
379 shape, constants, elementIt, indices);
380
381 // Replace tensor.from_elements.
382 rewriter.replaceOp(fromElementsOp, result);
383 return result.getDefiningOp();
384}
385
386/// Lower tensor.generate to linalg.generic.
387FailureOr<Operation *>
389 tensor::GenerateOp generateOp) {
390 // Only ops with exactly one block are supported.
391 if (!generateOp.getBody().hasOneBlock())
392 return failure();
393
394 Location loc = generateOp.getLoc();
395 RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
396
397 // Create tensor.empty.
398 auto emptyOp = EmptyOp::create(rewriter, loc, tensorType,
399 generateOp.getDynamicExtents());
400
401 // Create linalg.generic.
402 SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
403 utils::IteratorType::parallel);
404 SmallVector<AffineMap> indexingMaps(
405 1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
406 auto genericOp = linalg::GenericOp::create(
407 rewriter, loc, tensorType, /*inputs=*/ValueRange(),
408 /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/
409 indexingMaps, iteratorTypes);
410 Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
411 tensorType.getElementType(), loc);
412 rewriter.setInsertionPointToStart(body);
413 SmallVector<Value> bbArgReplacements;
414 for (int64_t i = 0; i < tensorType.getRank(); ++i)
415 bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
416 rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
417
418 // Update terminator.
419 auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
420 rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
421
422 // Replace tensor.generate.
423 rewriter.replaceOp(generateOp, genericOp->getResult(0));
424 return genericOp.getOperation();
425}
426
427/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
428FailureOr<Operation *>
430 tensor::PadOp padOp) {
431 // Only ops with exactly one block are supported.
432 if (!padOp.getBodyRegion().hasOneBlock())
433 return failure();
434
435 // Create tensor.empty.
436 Location loc = padOp.getLoc();
437 RankedTensorType resultType = padOp.getResultType();
438 ReifiedRankedShapedTypeDims reifiedShape;
439 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
440 return rewriter.notifyMatchFailure(
441 padOp, "failed to reify tensor.pad op result shape");
442 SmallVector<Value> dynamicSizes;
443 for (int64_t i = 0; i < resultType.getRank(); ++i)
444 if (resultType.isDynamicDim(i))
445 dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));
446
447 // If the `padOp` has a nofold attribute and all paddings are known to be 0,
448 // explicitly insert a `linalg.copy`.
449 if (padOp.getNofoldAttr() &&
450 llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) &&
451 llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {
452 using bufferization::AllocTensorOp;
453 Value allocated =
454 AllocTensorOp::create(rewriter, loc, resultType, dynamicSizes);
455 auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
456 padOp, padOp.getSource(), allocated);
457 return copyOp.getOperation();
458 }
459
460 Value empty = EmptyOp::create(rewriter, loc, resultType, dynamicSizes);
461 // Create linalg.fill or linalg.generic.
462 Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
463 rewriter.setInsertionPointAfter(fillOp);
464
465 // Create tensor::InsertSliceOp.
466 SmallVector<OpFoldResult> sliceSizes =
467 getMixedSizes(rewriter, loc, padOp.getSource());
468 SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
469 rewriter.getIndexAttr(1));
470 auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
471 padOp, padOp.getSource(), fillOp->getResult(0),
472 /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
473 return insertSliceOp.getOperation();
474}
475
478 Operation *op, Attribute memorySpace, Operation *insertionPoint) {
479 using namespace bufferization;
480
481 // Call specialized overload for certain ops.
482 if (auto padOp = dyn_cast<tensor::PadOp>(op))
483 return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
484 if (auto maskOp = dyn_cast<vector::MaskOp>(op))
485 return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
486 if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
487 return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);
488
489 // Only bufferizable ops are supported.
490 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
491 if (!bufferizableOp)
492 return nullptr;
493
494 // Should the bufferization options and states be function arguments?
495 BufferizationOptions bufferizationOptions;
496 AnalysisState analysisState(bufferizationOptions);
497 BufferizationState bufferizationState;
498
499#ifndef NDEBUG
500 if (!options.bufferizeDestinationOnly) {
501 // Ops with nested tensor ops are not supported yet. At the moment, this
502 // function just bufferizes the given op itself, but not its body.
503 op->walk([&](Operation *nestedOp) {
504 if (op == nestedOp)
505 return;
506 if (llvm::any_of(nestedOp->getOperands(),
507 [](Value v) { return isa<TensorType>(v.getType()); }))
508 llvm_unreachable("ops with nested tensor ops are not supported yet");
509 if (llvm::any_of(nestedOp->getResults(),
510 [](Value v) { return isa<TensorType>(v.getType()); }))
511 llvm_unreachable("ops with nested tensor ops are not supported yet");
512 });
513 }
514#endif // NDEBUG
515
516 // Gather tensor results.
517 SmallVector<OpResult> tensorResults;
518 for (OpResult result : op->getResults()) {
519 if (!isa<TensorType>(result.getType()))
520 continue;
521 // Unranked tensors are not supported
522 if (!isa<RankedTensorType>(result.getType()))
523 return nullptr;
524 // Ops that bufferize to an allocation are not supported.
525 if (bufferizableOp.bufferizesToAllocation(result))
526 return nullptr;
527 tensorResults.push_back(result);
528 }
529
530 // Gather all operands that should bufferize to a new allocation. I.e.,
531 // bufferize out-of-place.
532 SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
533 auto addOutOfPlaceOperand = [&](OpOperand *operand) {
534 if (!llvm::is_contained(outOfPlaceOperands, operand))
535 outOfPlaceOperands.push_back(operand);
536 };
537 for (OpResult result : tensorResults) {
538 AliasingOpOperandList aliasingOperands =
539 analysisState.getAliasingOpOperands(result);
540 for (const AliasingOpOperand &operand : aliasingOperands) {
541 addOutOfPlaceOperand(operand.opOperand);
542 for (OpOperand &resultUse : result.getUses())
543 resultUses.push_back(&resultUse);
544 }
545 }
546 for (OpOperand &operand : op->getOpOperands()) {
547 if (!analysisState.bufferizesToMemoryWrite(operand))
548 continue;
549 if (!isa<RankedTensorType>(operand.get().getType()))
550 continue;
551 addOutOfPlaceOperand(&operand);
552 }
553 // TODO: Support multiple buffers.
554 if (outOfPlaceOperands.size() != 1)
555 return nullptr;
556
557 // Allocate buffers.
558 OpBuilder::InsertionGuard g(rewriter);
559 rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
560 SmallVector<Value> allocs;
561 for (OpOperand *operand : outOfPlaceOperands) {
563 rewriter, op->getLoc(), operand->get(), options, memorySpace);
564 allocs.push_back(alloc);
565 if (!analysisState.findDefinitions(operand).empty()) {
566 // Initialize buffer with a copy of the operand data. Not needed if the
567 // tensor is uninitialized.
568 createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
569 }
570 rewriter.modifyOpInPlace(op, [&]() {
571 auto toTensorOp = ToTensorOp::create(rewriter, op->getLoc(),
572 operand->get().getType(), alloc);
573 operand->set(toTensorOp);
574 if (options.bufferizeDestinationOnly) {
575 rewriter.modifyOpInPlace(toTensorOp, [&]() {
576 toTensorOp.setRestrict(true);
577 toTensorOp.setWritable(true);
578 });
579 }
580 });
581 }
582
583 if (options.bufferizeDestinationOnly)
584 return allocs.front();
585
586 // Bufferize the op.
587 rewriter.setInsertionPoint(op);
588 if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,
589 bufferizationState)))
590 return nullptr;
591
592 // Set "restrict" attribute, indicating that no other tensor aliases with
593 // this tensor. That is because we just allocated a new buffer for the tensor.
594 for (OpOperand *resultUse : resultUses) {
595 auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
596 assert(toTensorOp && "expected to_tensor op");
597 rewriter.modifyOpInPlace(toTensorOp, [&]() {
598 toTensorOp.setRestrict(true);
599 toTensorOp.setWritable(true);
600 });
601 }
602 return allocs.front();
603}
604
605namespace {
606
607template <typename OpTy>
608LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
609 PatternRewriter &rewriter) {
610 return linalg::rewriteInDestinationPassingStyle(rewriter, op);
611}
612
613} // namespace
614
617 patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
618 patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
619 patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
620}
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace={})
static SmallVector< Value > reifyOrComputeDynamicSizes(OpBuilder &b, Value value)
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)
Create a memcpy from the given source tensor to the given destination memref.
static Operation * movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)
static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef< int64_t > shape, ArrayRef< Value > constants, OperandRange::iterator &elementIt, SmallVectorImpl< Value > &indices)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:387
MLIRContext * getContext() const
Definition Builders.h:56
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:442
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)
Populate patterns that convert non-destination-style ops to destination style ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369