24 #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
33 assert(inputs.size() == 1);
34 auto inputType = inputs[0].
getType();
35 if (isa<TensorType>(inputType))
40 return builder.
create<tensor::FromElementsOp>(
52 return tensorType.
hasRank() && tensorType.getRank() == 0;
56 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
58 llvm::all_of(genericOp->getOpOperands(), [&](
OpOperand &opOperand) {
59 return !typeConverter.isLegal(opOperand.get().getType());
68 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
70 Block *originalBlock = op->getBlock();
73 Block *opEntryBlock = &*op.getRegion().
begin();
74 YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
83 rewriter.
replaceOp(op, yieldOp->getOperands());
86 rewriter.
mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
97 struct FunctionNonEntryBlockConversion
102 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
108 Region &region = op.getFunctionBody();
111 llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
113 block.getNumArguments());
116 int idx = blockArgument.getArgNumber();
118 if (blockArgsToDetensor.count(blockArgument))
119 conversion.addInputs(idx, {getTypeConverter()->convertType(
120 block.getArgumentTypes()[idx])});
122 conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
138 DetensorizeTypeConverter() {
139 addConversion([](
Type type) {
return type; });
144 if (canBeDetensored(tensorType))
161 struct LinalgDetensorize
162 :
public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
163 using impl::LinalgDetensorizePassBase<
164 LinalgDetensorize>::LinalgDetensorizePassBase;
165 LinalgDetensorize() =
default;
169 virtual ~CostModel() =
default;
203 virtual void compute(FunctionOpInterface func,
204 DetensorizeTypeConverter typeConverter,
219 for (
auto blockArgumentElem : blockArgsToDetensor) {
220 Block *block = blockArgumentElem.getOwner();
223 pred != block->
pred_end(); ++pred) {
224 BranchOpInterface terminator =
225 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
227 terminator.getSuccessorOperands(pred.getSuccessorIndex());
229 if (blockOperands.empty() ||
230 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
233 detensorableBranchOps[terminator].insert(
234 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
238 return detensorableBranchOps;
253 class ControlFlowDetectionModel :
public CostModel {
255 void compute(FunctionOpInterface func,
256 DetensorizeTypeConverter typeConverter,
261 func->walk([&](cf::CondBranchOp condBr) {
262 llvm::append_range(workList, condBr.getOperands());
265 func->walk([&](cf::BranchOp br) {
266 llvm::append_range(workList, br.getOperands());
275 auto updateWorkListWithSuccessorArguments =
276 [&](
Value value, BranchOpInterface terminator) {
280 for (
auto operandIdx :
281 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
282 Value operand = terminator->getOperand(operandIdx);
284 if (operand == value) {
286 terminator.getSuccessorBlockArgument(operandIdx);
288 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
289 workList.push_back(*succBlockArg);
294 while (!workList.empty()) {
295 Value currentItem = workList.pop_back_val();
297 if (!visitedValues.insert(currentItem).second)
303 updateWorkListWithSuccessorArguments(
304 currentItem, dyn_cast<BranchOpInterface>(
311 for (
auto *user : currentItem.
getUsers())
312 llvm::append_range(workList, user->getResults());
320 if (dyn_cast<BlockArgument>(currentItem)) {
322 cast<BlockArgument>(currentItem);
331 blockArgsToDetensor.insert(currentItemBlockArgument);
334 pred != ownerBlock->
pred_end(); ++pred) {
335 BranchOpInterface predTerminator =
336 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
340 if (!predTerminator) {
341 opsToDetensor.clear();
342 blockArgsToDetensor.clear();
346 auto ownerBlockOperands =
347 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
349 if (ownerBlockOperands.empty() ||
350 ownerBlockOperands.isOperandProduced(
357 ownerBlockOperands[currentItemBlockArgument.
getArgNumber()]);
365 if (!visitedOps.insert(currentItemDefiningOp).second)
373 if (
auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
375 if (opsToDetensor.count(genericOp))
380 if (!shouldBeDetensored(genericOp, typeConverter)) {
384 opsToDetensor.insert(genericOp);
385 llvm::append_range(workList, genericOp.getInputs());
396 if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
403 [&](
Type resultType) { return resultType.isIntOrFloat(); }))
404 llvm::append_range(workList, currentItemDefiningOp->
getOperands());
413 for (
auto &blockArg : blockArgsToDetensor) {
414 Block *block = blockArg.getParentBlock();
419 pred != block->
pred_end(); ++pred) {
420 BranchOpInterface terminator =
421 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
423 terminator.getSuccessorOperands(pred.getSuccessorIndex());
425 if (blockOperands.empty() ||
426 blockOperands.isOperandProduced(blockArg.getArgNumber()))
430 blockOperands[blockArg.getArgNumber()].getDefiningOp();
434 if (isa_and_nonnull<GenericOp>(definingOp) &&
435 opsToDetensor.count(definingOp) == 0) {
436 blockArgsToRemove.insert(blockArg);
442 for (
auto &blockArg : blockArgsToRemove) {
443 blockArgsToDetensor.erase(blockArg);
449 class AggressiveDetensoringModel :
public CostModel {
451 void compute(FunctionOpInterface func,
452 DetensorizeTypeConverter typeConverter,
455 func->walk([&](GenericOp genericOp) {
456 if (shouldBeDetensored(genericOp, typeConverter))
457 opsToDetensor.insert(genericOp);
460 for (
Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
462 blockArgsToDetensor.insert(blockArgument);
466 void runOnOperation()
override {
468 DetensorizeTypeConverter typeConverter;
474 FunctionOpInterface funcOp = getOperation();
476 if (funcOp.getFunctionBody().empty())
484 Block *entryBlock = &funcOp.getFunctionBody().
front();
485 Block *postEntryBlock =
491 if (aggressiveMode.getValue()) {
492 AggressiveDetensoringModel costModel;
493 costModel.compute(funcOp, typeConverter, opsToDetensor,
494 blockArgsToDetensor);
496 ControlFlowDetectionModel costModel;
497 costModel.compute(funcOp, typeConverter, opsToDetensor,
498 blockArgsToDetensor);
501 detensorableBranchOps =
502 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
504 target.addDynamicallyLegalOp<GenericOp>(
505 [&](GenericOp op) {
return !opsToDetensor.count(op); });
507 target.markUnknownOpDynamicallyLegal([&](
Operation *op) {
513 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
514 Region &body = funcOp.getFunctionBody();
515 return llvm::all_of(llvm::drop_begin(body, 1), [&](
Block &block) {
516 return !llvm::any_of(
518 return blockArgument.
getOwner() == &block &&
519 !typeConverter.isLegal(blockArgument.
getType());
529 if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
530 if (!detensorableBranchOps.count(branchOp))
533 for (
auto operandIdx : detensorableBranchOps[branchOp])
534 if (!typeConverter.isLegal(
535 branchOp->getOperand(operandIdx).getType()))
544 patterns.add<DetensorizeGenericOp>(typeConverter, context);
545 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
546 blockArgsToDetensor);
550 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
551 int operandIdx) ->
bool {
552 return detensorableBranchOps.count(branchOp) &&
553 detensorableBranchOps[branchOp].count(operandIdx);
557 shouldConvertBranchOperand);
564 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
This class describes a specific conversion target.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
bool isLegalForReturnOpTypeConversionPattern(Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
const FrozenRewritePatternSet & patterns
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...