26#include "llvm/ADT/STLExtras.h"
36 OperandRange::iterator &elementIt,
38 if (dim ==
static_cast<int>(
shape.size()) - 1) {
39 for (
int i = 0; i <
shape.back(); ++i) {
41 destination = tensor::InsertOp::create(rewriter, loc, *elementIt,
47 for (
int i = 0; i <
shape[dim]; ++i) {
60 auto tensorType = dyn_cast<RankedTensorType>(tensorSource.
getType());
61 assert(tensorType &&
"expected ranked tensor");
62 assert(isa<MemRefType>(memrefDest.
getType()) &&
"expected ranked memref");
65 case linalg::BufferizeToAllocationOptions::MemcpyOp::
66 MaterializeInDestination: {
69 auto materializeOp = bufferization::MaterializeInDestinationOp::create(
70 b, loc, tensorSource, memrefDest);
71 materializeOp.setWritable(
true);
77 Value toBuffer = bufferization::ToBufferOp::create(
78 b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
80 memref::CopyOp::create(
b, loc, toBuffer, memrefDest);
86 Value toBuffer = bufferization::ToBufferOp::create(
87 b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
89 linalg::CopyOp::create(
b, loc, toBuffer, memrefDest);
98 RankedTensorType resultType = padOp.getResultType();
108 cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
112 isa<BlockArgument>(yieldedValue) &&
113 cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
114 padOp.getOperation();
116 bool outsideOpResult =
117 isa<OpResult>(yieldedValue) &&
119 bool invariantYieldedValue = outsideBbArg || outsideOpResult;
129 auto fillOp = linalg::FillOp::create(rewriter, loc,
ValueRange(fillValue),
131 fillOp->setDiscardableAttrs(preservedAttrs);
135 if (invariantYieldedValue) {
137 auto fillOp = linalg::FillOp::create(
139 fillOp->setDiscardableAttrs(preservedAttrs);
145 utils::IteratorType::parallel);
148 auto genericOp = linalg::GenericOp::create(
151 indexingMaps, iteratorTypes);
152 genericOp->setDiscardableAttrs(preservedAttrs);
154 resultType.getElementType(), loc);
157 for (
int64_t i = 0; i < resultType.getRank(); ++i)
158 bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
159 rewriter.
mergeBlocks(padOp.getBody(), body, bbArgReplacements);
169 auto tensorType = cast<RankedTensorType>(value.
getType());
170 if (tensorType.hasStaticShape())
175 if (isa<OpResult>(value) &&
178 for (
int64_t i = 0; i < tensorType.getRank(); ++i) {
179 if (tensorType.isDynamicDim(i))
180 dynSizes.push_back(cast<Value>(
181 reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
188 for (
int64_t i = 0; i < tensorType.getRank(); ++i) {
189 if (tensorType.isDynamicDim(i))
191 DimOp::create(
b, value.
getLoc(), value,
202 auto tensorType = cast<RankedTensorType>(value.
getType());
206 cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
207 tensorType, memorySpace));
213 alloc = memref::AllocOp::create(rewriter, loc, memrefType, dynamicSizes);
217 memref::DeallocOp::create(rewriter, loc, alloc);
221 alloc = memref::AllocaOp::create(rewriter, loc, memrefType, dynamicSizes);
232 assert(!
options.bufferizeDestinationOnly &&
"invalid options");
243 if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
255 Value subview = memref::SubViewOp::create(
256 rewriter, loc, alloc, padOp.getMixedLowPad(), sizes, strides);
261 Value toTensorOp = bufferization::ToTensorOp::create(
262 rewriter, loc, padOp.getResult().getType(), alloc,
true,
271 assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
272 "expected single masked op");
276 bufferization::BufferizationOptions bufferizationOptions;
277 bufferization::BufferizationState bufferizationState;
279 Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
280 assert(isa<vector::YieldOp>(yieldOp) &&
"expected yield op terminator");
285 rewriter,
options, maskOp.getMaskableOp(), memorySpace,
286 insertionPoint ? insertionPoint : maskOp);
288 if (
options.bufferizeDestinationOnly)
293 if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
294 rewriter, bufferizationOptions, bufferizationState)))
301 maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
302 if (toTensorOp->getUses().empty())
303 toTensorOps.push_back(toTensorOp.getOperation());
311 if (isa<TensorType>(
result.getType()))
313 resultUses.push_back(&use);
316 cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
317 .bufferize(rewriter, bufferizationOptions, bufferizationState)))
322 for (
OpOperand *resultUse : resultUses) {
324 resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
325 assert(toTensorOp &&
"expected to_tensor op");
327 toTensorOp.setRestrict(
true);
328 toTensorOp.setWritable(
true);
337 bufferization::AllocTensorOp allocTensorOp,
Attribute memorySpace,
339 Location loc = allocTensorOp.getLoc();
342 bufferization::BufferizationOptions bufferizationOptions;
346 rewriter, loc, allocTensorOp.getResult(),
options, memorySpace);
350 Value toTensorOp = bufferization::ToTensorOp::create(
351 rewriter, loc, allocTensorOp.getResult().getType(), alloc,
354 rewriter.
replaceOp(allocTensorOp, toTensorOp);
360 RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
361 Location loc = fromElementsOp.getLoc();
362 RankedTensorType tensorType =
363 cast<RankedTensorType>(fromElementsOp.getType());
364 auto shape = tensorType.getShape();
367 auto emptyOp = EmptyOp::create(rewriter, loc, tensorType,
ValueRange());
372 fromElementsOp, fromElementsOp.getElements().front(),
378 auto maxDim = *llvm::max_element(
shape);
380 constants.reserve(maxDim);
381 for (
int i = 0; i < maxDim; ++i)
385 auto elementIt = fromElementsOp.getElements().begin();
392 return result.getDefiningOp();
396FailureOr<Operation *>
398 tensor::GenerateOp generateOp) {
400 if (!generateOp.getBody().hasOneBlock())
404 RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
407 auto emptyOp = EmptyOp::create(rewriter, loc, tensorType,
408 generateOp.getDynamicExtents());
412 utils::IteratorType::parallel);
415 auto genericOp = linalg::GenericOp::create(
418 indexingMaps, iteratorTypes);
420 tensorType.getElementType(), loc);
423 for (
int64_t i = 0; i < tensorType.getRank(); ++i)
424 bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
425 rewriter.
mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
428 auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
432 rewriter.
replaceOp(generateOp, genericOp->getResult(0));
433 return genericOp.getOperation();
437FailureOr<Operation *>
439 tensor::PadOp padOp) {
441 if (!padOp.getBodyRegion().hasOneBlock())
446 RankedTensorType resultType = padOp.getResultType();
450 padOp,
"failed to reify tensor.pad op result shape");
452 for (
int64_t i = 0; i < resultType.getRank(); ++i)
453 if (resultType.isDynamicDim(i))
454 dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));
458 if (padOp.getNofoldAttr() &&
461 using bufferization::AllocTensorOp;
463 AllocTensorOp::create(rewriter, loc, resultType, dynamicSizes);
465 padOp, padOp.getSource(), allocated);
466 return copyOp.getOperation();
469 Value empty = EmptyOp::create(rewriter, loc, resultType, dynamicSizes);
480 padOp, padOp.getSource(), fillOp->
getResult(0),
481 padOp.getMixedLowPad(), sliceSizes, sliceStrides);
482 return insertSliceOp.getOperation();
491 if (
auto padOp = dyn_cast<tensor::PadOp>(op))
493 if (
auto maskOp = dyn_cast<vector::MaskOp>(op))
495 if (
auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
499 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
506 BufferizationState bufferizationState;
509 if (!
options.bufferizeDestinationOnly) {
516 [](
Value v) { return isa<TensorType>(v.getType()); }))
517 llvm_unreachable(
"ops with nested tensor ops are not supported yet");
519 [](
Value v) { return isa<TensorType>(v.getType()); }))
520 llvm_unreachable(
"ops with nested tensor ops are not supported yet");
528 if (!isa<TensorType>(
result.getType()))
531 if (!isa<RankedTensorType>(
result.getType()))
534 if (bufferizableOp.bufferizesToAllocation(
result))
536 tensorResults.push_back(
result);
542 auto addOutOfPlaceOperand = [&](
OpOperand *operand) {
543 if (!llvm::is_contained(outOfPlaceOperands, operand))
544 outOfPlaceOperands.push_back(operand);
547 AliasingOpOperandList aliasingOperands =
548 analysisState.getAliasingOpOperands(
result);
549 for (
const AliasingOpOperand &operand : aliasingOperands) {
550 addOutOfPlaceOperand(operand.opOperand);
552 resultUses.push_back(&resultUse);
556 if (!analysisState.bufferizesToMemoryWrite(operand))
558 if (!isa<RankedTensorType>(operand.get().getType()))
560 addOutOfPlaceOperand(&operand);
563 if (outOfPlaceOperands.size() != 1)
570 for (
OpOperand *operand : outOfPlaceOperands) {
572 rewriter, op->
getLoc(), operand->get(),
options, memorySpace);
573 allocs.push_back(alloc);
574 if (!analysisState.findDefinitions(operand).empty()) {
580 auto toTensorOp = ToTensorOp::create(rewriter, op->
getLoc(),
581 operand->get().getType(), alloc);
582 operand->set(toTensorOp);
583 if (
options.bufferizeDestinationOnly) {
585 toTensorOp.setRestrict(
true);
586 toTensorOp.setWritable(
true);
592 if (
options.bufferizeDestinationOnly)
593 return allocs.front();
597 if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,
598 bufferizationState)))
603 for (
OpOperand *resultUse : resultUses) {
604 auto toTensorOp = resultUse->get().
getDefiningOp<ToTensorOp>();
605 assert(toTensorOp &&
"expected to_tensor op");
607 toTensorOp.setRestrict(
true);
608 toTensorOp.setWritable(
true);
611 return allocs.front();
616template <
typename OpTy>
617LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
626 patterns.
add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
627 patterns.
add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
628 patterns.
add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace={})
static SmallVector< Value > reifyOrComputeDynamicSizes(OpBuilder &b, Value value)
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)
Create a memcpy from the given source tensor to the given destination memref.
static Operation * movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)
static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef< int64_t > shape, ArrayRef< Value > constants, OperandRange::iterator &elementIt, SmallVectorImpl< Value > &indices)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
IntegerAttr getIndexAttr(int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
MLIRContext * getContext() const
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing `op`).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg.fill/linalg.generic + bufferization.materialize_in_destination.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements in destination-passing style: a tensor.empty op followed by a chain of tensor.insert ops.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)
Populate patterns that convert non-destination-style ops to destination style ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)