#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;
/// Create a chain of tensor.insert ops that writes the elements pointed to by
/// `elementIt` into `destination`, recursing over `shape` one dimension per
/// call.
static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
                           Value destination, ArrayRef<int64_t> shape,
                           ArrayRef<Value> constants,
                           OperandRange::iterator &elementIt,
                           SmallVectorImpl<Value> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    // Innermost dimension: insert one element per iteration.
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
                                                      destination, indices);
      ++elementIt;
    }
    return destination;
  }
  // Outer dimension: fix the index and recurse into the next dimension.
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    destination = createInserts(rewriter, loc, dim + 1, destination, shape,
                                constants, elementIt, indices);
  }
  return destination;
}
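// For illustration (a sketch, not part of the original source): lowering a
// 2x2 tensor.from_elements with this helper produces a row-major chain of
// tensor.insert ops. Value names below are hypothetical.
//
//   %c0 = arith.constant 0 : index
//   %c1 = arith.constant 1 : index
//   %empty = tensor.empty() : tensor<2x2xf32>
//   %i0 = tensor.insert %a into %empty[%c0, %c0] : tensor<2x2xf32>
//   %i1 = tensor.insert %b into %i0[%c0, %c1] : tensor<2x2xf32>
//   %i2 = tensor.insert %c into %i1[%c1, %c0] : tensor<2x2xf32>
//   %res = tensor.insert %d into %i2[%c1, %c1] : tensor<2x2xf32>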
/// Create a memcpy from the given source tensor to the given destination
/// memref. The copy op flavor is selected via `options`.
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
                         Value memrefDest,
                         const linalg::BufferizeToAllocationOptions &options) {
  auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
  assert(tensorType && "expected ranked tensor");
  assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");

  switch (options.memcpyOp) {
  case linalg::BufferizeToAllocationOptions::MemcpyOp::
      MaterializeInDestination: {
    // Preferred flavor: no layout map or memory space must be specified for
    // the source.
    auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
        loc, tensorSource, memrefDest);
    materializeOp.setWritable(true);
  } break;
  case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: {
    // The layout map of the source is not known yet, so use a fully dynamic
    // layout for maximum compatibility.
    Value toMemref = b.create<bufferization::ToMemrefOp>(
        loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
        tensorSource);
    b.create<memref::CopyOp>(loc, toMemref, memrefDest);
  } break;
  case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: {
    // Same fully dynamic layout as above.
    Value toMemref = b.create<bufferization::ToMemrefOp>(
        loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
        tensorSource);
    b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
  } break;
  };
}
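// For illustration (a hedged sketch; exact result types depend on the
// source): the three flavors emit IR along these lines for a tensor %src and
// a memref %dest:
//
//   bufferization.materialize_in_destination %src in writable %dest
//       : (tensor<?xf32>, memref<?xf32>) -> ()
//
// or, going through a to_memref with fully dynamic layout:
//
//   %m = bufferization.to_memref %src
//       : memref<?xf32, strided<[?], offset: ?>>
//   memref.copy %m, %dest
//
// or the same to_memref followed by:
//
//   linalg.copy ins(%m : memref<...>) outs(%dest : memref<?xf32>)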
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                               Location loc, PadOp padOp,
                                               Value dest) {
  OpBuilder::InsertionGuard g(rewriter);
  RankedTensorType resultType = padOp.getResultType();

  // Examine the yielded padding value: if it is defined outside of the pad
  // op's body, a linalg.fill suffices; otherwise a linalg.generic is needed.
  Value yieldedValue =
      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
  bool outsideBbArg =
      isa<BlockArgument>(yieldedValue) &&
      cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
          padOp.getOperation();
  bool outsideOpResult =
      isa<OpResult>(yieldedValue) &&
      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
  // ... (constant yielded values are rematerialized outside the op; elided)

  if (invariantYieldedValue) {
    // Padding with an invariant value: create linalg.fill.
    auto fillOp = rewriter.create<linalg::FillOp>(loc,
                                                  ValueRange(yieldedValue),
                                                  ValueRange(dest));
    return fillOp.getOperation();
  }

  // Otherwise, create a linalg.generic that recomputes the padding value at
  // every output position.
  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, resultType, /*inputs=*/ValueRange(), /*outputs=*/ValueRange(dest),
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     resultType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);

  // Update the terminator: tensor.yield becomes linalg.yield.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
  return genericOp.getOperation();
}
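// For illustration (hedged sketch): an invariant padding value lowers to
//
//   %fill = linalg.fill ins(%cst : f32) outs(%dest : tensor<?x?xf32>)
//       -> tensor<?x?xf32>
//
// while an index-dependent padding value lowers to a linalg.generic whose
// body rebuilds the pad op's body on top of linalg.index ops.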
/// Compute the dynamic sizes of `value`, preferably by reifying the result
/// shape of its defining op, otherwise by creating tensor.dim ops.
static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
                                                     Value value) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  if (tensorType.hasStaticShape())
    return {};

  // Try to reify dynamic sizes.
  ReifiedRankedShapedTypeDims reifiedShape;
  if (isa<OpResult>(value) &&
      succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
    SmallVector<Value> dynSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i) {
      if (tensorType.isDynamicDim(i))
        dynSizes.push_back(
            reifiedShape[cast<OpResult>(value).getResultNumber()][i]
                .get<Value>());
    }
    return dynSizes;
  }

  // Fall back to creating tensor.dim ops for each dynamic dimension.
  SmallVector<Value> dynSizes;
  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
    if (tensorType.isDynamicDim(i))
      dynSizes.push_back(b.create<tensor::DimOp>(value.getLoc(), value, i));
  }
  return dynSizes;
}
static Value
createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value,
                          const linalg::BufferizeToAllocationOptions &options,
                          Attribute memorySpace = {}) {
  OpBuilder::InsertionGuard g(rewriter);
  auto tensorType = cast<RankedTensorType>(value.getType());

  // Create a buffer allocation with a static identity layout.
  auto memrefType =
      cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorType, memorySpace));
  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);

  Value alloc;
  if (options.allocOp ==
      linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) {
    alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
    if (options.emitDealloc) {
      // Place the deallocation at the end of the block.
      rewriter.setInsertionPoint(
          rewriter.getInsertionBlock()->getTerminator());
      rewriter.create<memref::DeallocOp>(loc, alloc);
    }
  } else if (options.allocOp ==
             linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) {
    alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
  }
  return alloc;
}
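// For illustration (hedged sketch) for a value of type tensor<?x16xf32>,
// assuming the memref.alloc option with deallocation enabled:
//
//   %d0 = tensor.dim %value, %c0 : tensor<?x16xf32>
//   %alloc = memref.alloc(%d0) : memref<?x16xf32>
//   ...
//   memref.dealloc %alloc : memref<?x16xf32>  // at the end of the block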
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    tensor::PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
  // tensor.pad does not have a destination operand.
  assert(!options.bufferizeDestinationOnly && "invalid options");

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
  Location loc = padOp.getLoc();

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
                                          options, memorySpace);
  rewriter.setInsertionPoint(padOp);

  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
    // Initialize the buffer with the padding value.
    // ...
  }

  // Copy the source into the interior subview of the allocation.
  SmallVector<OpFoldResult> sizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
                                    rewriter.getIndexAttr(1));
  Value subview = rewriter.create<memref::SubViewOp>(
      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);

  // Create bufferization.to_tensor with "restrict" and "writable".
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(padOp, toTensorOp);
  return alloc;
}
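// For illustration (hedged sketch): bufferizing
//
//   %0 = tensor.pad %t low[1] high[1] { ... tensor.yield %cst ... }
//       : tensor<8xf32> to tensor<10xf32>
//
// yields roughly:
//
//   %alloc = memref.alloc() : memref<10xf32>
//   ... initialize %alloc with %cst ...
//   %sv = memref.subview %alloc[1] [8] [1] : memref<10xf32> to ...
//   bufferization.materialize_in_destination %t in writable %sv
//   %0 = bufferization.to_tensor %alloc restrict writable : memref<10xf32>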
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
         "expected single masked op");
  OpBuilder::InsertionGuard g(rewriter);
  bufferization::BufferizationOptions bufferizationOptions;
  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");

  // Bufferize the masked op. By default, place the allocation right before
  // the vector.mask op.
  Value alloc = bufferizeToAllocation(rewriter, options,
                                      maskOp.getMaskableOp(), memorySpace,
                                      insertionPoint ? insertionPoint : maskOp);
  if (options.bufferizeDestinationOnly)
    return alloc;

  // Bufferize the terminator.
  rewriter.setInsertionPoint(yieldOp);
  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
          rewriter, bufferizationOptions)))
    return nullptr;

  // Collect now-dead to_tensor ops inside the mask region; they are moved
  // out of the region (elided).
  SmallVector<Operation *> toTensorOps;
  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty())
      toTensorOps.push_back(toTensorOp.getOperation());
  });
  // ...

  // Remember the uses of the mask op's tensor results, then bufferize the
  // mask op itself.
  SmallVector<OpOperand *> resultUses;
  for (Value result : maskOp.getResults())
    if (isa<TensorType>(result.getType()))
      for (OpOperand &use : result.getUses())
        resultUses.push_back(&use);
  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
                 .bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" and "writable": no other tensor aliases this buffer, so
  // the bufferization does not insert any copies.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp =
        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    toTensorOp.setRestrict(true);
    toTensorOp.setWritable(true);
  }
  return alloc;
}
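// For illustration (hedged sketch): for a masked transfer_write into a
// tensor,
//
//   %0 = vector.mask %m { vector.transfer_write %v, %t[%c0]
//       : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
//
// the destination %t is replaced by a to_tensor of the new allocation, the
// masked op and the mask op are bufferized in place, and %0 becomes a
// bufferization.to_tensor marked restrict and writable.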
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
    Operation *insertionPoint) {
  Location loc = allocTensorOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);

  // Create buffer allocation.
  Value alloc = createAllocationForTensor(rewriter, loc,
                                          allocTensorOp.getResult(), options,
                                          memorySpace);

  // Create bufferization.to_tensor with "restrict" and "writable".
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(allocTensorOp, toTensorOp);
  return alloc;
}
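// For illustration (hedged sketch):
//
//   %0 = bufferization.alloc_tensor(%d0) : tensor<?xf32>
//
// becomes:
//
//   %alloc = memref.alloc(%d0) : memref<?xf32>
//   %0 = bufferization.to_tensor %alloc restrict writable : memref<?xf32>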
/// Lower tensor.from_elements to a chain of tensor.insert ops.
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(
    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
  Location loc = fromElementsOp.getLoc();
  RankedTensorType tensorType =
      cast<RankedTensorType>(fromElementsOp.getType());
  auto shape = tensorType.getShape();

  // Create tensor.empty.
  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());

  // Case: rank-0 tensor.
  if (shape.empty()) {
    Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
        fromElementsOp, fromElementsOp.getElements().front(),
        emptyOp.getResult(), ValueRange());
    return res;
  }

  // Create index constants for the range [0, max(shape)).
  auto maxDim = *llvm::max_element(shape);
  SmallVector<Value, 2> constants;
  constants.reserve(maxDim);
  for (int i = 0; i < maxDim; ++i)
    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

  // Traverse all elements and create tensor.insert ops.
  auto elementIt = fromElementsOp.getElements().begin();
  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
                               shape, constants, elementIt, indices);

  // Replace tensor.from_elements.
  rewriter.replaceOp(fromElementsOp, result);
  return result.getDefiningOp();
}
/// Lower tensor.generate to linalg.generic.
FailureOr<Operation *>
linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                         tensor::GenerateOp generateOp) {
  // Only ops with exactly one block are supported.
  if (!generateOp.getBody().hasOneBlock())
    return failure();

  Location loc = generateOp.getLoc();
  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());

  // Create tensor.empty.
  auto emptyOp =
      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());

  // Create linalg.generic.
  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, tensorType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{emptyOp.getResult()},
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     tensorType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);

  // Update the terminator: tensor.yield becomes linalg.yield.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  // Replace tensor.generate.
  rewriter.replaceOp(generateOp, genericOp->getResult(0));
  return genericOp.getOperation();
}
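// For illustration (hedged sketch):
//
//   %0 = tensor.generate %sz {
//   ^bb0(%i: index, %j: index):
//     %v = ... : f32
//     tensor.yield %v : f32
//   } : tensor<4x?xf32>
//
// becomes:
//
//   %empty = tensor.empty(%sz) : tensor<4x?xf32>
//   %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
//                        iterator_types = ["parallel", "parallel"]}
//       outs(%empty : tensor<4x?xf32>) {
//   ^bb0(%out: f32):
//     %i = linalg.index 0 : index
//     %j = linalg.index 1 : index
//     %v = ... : f32
//     linalg.yield %v : f32
//   } -> tensor<4x?xf32>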
/// Lower tensor.pad to linalg.fill/linalg.generic plus tensor.insert_slice.
FailureOr<Operation *>
linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                         tensor::PadOp padOp) {
  // Only ops with exactly one block are supported.
  if (!padOp.getBodyRegion().hasOneBlock())
    return failure();

  // Reify the result shape to compute dynamic sizes.
  Location loc = padOp.getLoc();
  RankedTensorType resultType = padOp.getResultType();
  ReifiedRankedShapedTypeDims reifiedShape;
  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
    return rewriter.notifyMatchFailure(
        padOp, "failed to reify tensor.pad op result shape");
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    if (resultType.isDynamicDim(i))
      dynamicSizes.push_back(reifiedShape[0][i].get<Value>());

  // If the padding is a no-op (nofold with all-zero low/high pad), lower to
  // an explicit linalg.copy instead.
  if (padOp.getNofoldAttr() &&
      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
    using bufferization::AllocTensorOp;
    Value allocated =
        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
        padOp, padOp.getSource(), allocated);
    return copyOp.getOperation();
  }

  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
  // Create linalg.fill or linalg.generic for the padding value.
  Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
  rewriter.setInsertionPointAfter(fillOp);

  // Insert the source into the filled tensor.
  SmallVector<OpFoldResult> sliceSizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                         rewriter.getIndexAttr(1));
  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fillOp->getResult(0),
      /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
  return insertSliceOp.getOperation();
}
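// For illustration (hedged sketch): with a constant padding value,
//
//   %0 = tensor.pad %t low[0, 1] high[0, 1] { ... tensor.yield %cst ... }
//       : tensor<4x8xf32> to tensor<4x10xf32>
//
// becomes:
//
//   %empty = tensor.empty() : tensor<4x10xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x10xf32>)
//       -> tensor<4x10xf32>
//   %0 = tensor.insert_slice %t into %fill[0, 1] [4, 8] [1, 1]
//       : tensor<4x8xf32> into tensor<4x10xf32>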
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace, Operation *insertionPoint) {
  using namespace bufferization;

  // Call the specialized overloads for certain ops.
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
    return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
    return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

  // Only bufferizable ops are supported.
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  BufferizationOptions bufferizationOptions;
  AnalysisState state(bufferizationOptions);

  if (!options.bufferizeDestinationOnly) {
    // Ops with nested tensor ops are not supported yet: this function
    // bufferizes only the given op itself, not its body.
    op->walk([&](Operation *nestedOp) {
      if (op == nestedOp)
        return;
      if (llvm::any_of(nestedOp->getOperands(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
      if (llvm::any_of(nestedOp->getResults(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
    });
  }

  // Gather tensor results.
  SmallVector<OpResult> tensorResults;
  for (OpResult result : op->getResults()) {
    if (!isa<TensorType>(result.getType()))
      continue;
    // Unranked tensors are not supported.
    if (!isa<RankedTensorType>(result.getType()))
      return nullptr;
    // Ops that bufferize to an allocation are not supported.
    if (bufferizableOp.bufferizesToAllocation(result))
      return nullptr;
    tensorResults.push_back(result);
  }

  // Gather all operands that should bufferize to a new allocation, i.e.,
  // bufferize out-of-place.
  SmallVector<OpOperand *> outOfPlaceOperands, resultUses;
  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
    if (!llvm::is_contained(outOfPlaceOperands, operand))
      outOfPlaceOperands.push_back(operand);
  };
  for (OpResult result : tensorResults) {
    AliasingOpOperandList aliasingOperands =
        state.getAliasingOpOperands(result);
    for (const AliasingOpOperand &operand : aliasingOperands) {
      addOutOfPlaceOperand(operand.opOperand);
      for (OpOperand &resultUse : result.getUses())
        resultUses.push_back(&resultUse);
    }
  }
  for (OpOperand &operand : op->getOpOperands()) {
    if (!state.bufferizesToMemoryWrite(operand))
      continue;
    if (!isa<RankedTensorType>(operand.get().getType()))
      continue;
    addOutOfPlaceOperand(&operand);
  }
  // TODO: Support multiple buffers.
  if (outOfPlaceOperands.size() != 1)
    return nullptr;

  // Allocate buffers.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
  SmallVector<Value> allocs;
  for (OpOperand *operand : outOfPlaceOperands) {
    Value alloc = createAllocationForTensor(
        rewriter, op->getLoc(), operand->get(), options, memorySpace);
    allocs.push_back(alloc);
    if (!state.findDefinitions(operand->get()).empty()) {
      // Initialize the buffer with a copy of the operand data. Not needed if
      // the tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
    }
    auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
    operand->set(toTensorOp);
    if (options.bufferizeDestinationOnly) {
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    }
  }

  if (options.bufferizeDestinationOnly)
    return allocs.front();

  // Bufferize the op.
  rewriter.setInsertionPoint(op);
  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" and "writable" on the newly created to_tensor ops,
  // indicating that no other tensor aliases this buffer; this prevents the
  // bufferization from inserting copies.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    toTensorOp.setRestrict(true);
    toTensorOp.setWritable(true);
  }
  return allocs.front();
}
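// For illustration (hedged sketch): with options.bufferizeDestinationOnly
// set, bufferizing the destination of a DPS op such as
//
//   %0 = linalg.matmul ins(%a, %b : ...) outs(%dest : tensor<?x?xf32>)
//
// yields roughly:
//
//   %alloc = memref.alloc(...) : memref<?x?xf32>
//   bufferization.materialize_in_destination %dest in writable %alloc
//   %t = bufferization.to_tensor %alloc restrict writable : memref<?x?xf32>
//   %0 = linalg.matmul ins(%a, %b : ...) outs(%t : tensor<?x?xf32>)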
namespace {
template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
                                                 PatternRewriter &rewriter) {
  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
}
} // namespace

/// Populate patterns that convert non-destination-style ops to destination
/// style ops.
void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
}