24 #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
33 assert(inputs.size() == 1);
34 auto inputType = inputs[0].
getType();
35 if (isa<TensorType>(inputType))
40 return builder.
create<tensor::FromElementsOp>(
52 return tensorType.
hasRank() && tensorType.getRank() == 0;
56 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
58 llvm::all_of(genericOp->getOpOperands(), [&](
OpOperand &opOperand) {
59 return !typeConverter.isLegal(opOperand.get().getType());
68 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
70 Block *originalBlock = op->getBlock();
73 Block *opEntryBlock = &*op.getRegion().
begin();
74 YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
83 rewriter.
replaceOp(op, yieldOp->getOperands());
86 rewriter.
mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
97 struct FunctionNonEntryBlockConversion
102 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
108 Region &region = op.getFunctionBody();
111 llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
113 block.getNumArguments());
116 int idx = blockArgument.getArgNumber();
118 if (blockArgsToDetensor.count(blockArgument))
119 conversion.addInputs(idx, {getTypeConverter()->convertType(
120 block.getArgumentTypes()[idx])});
122 conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
138 DetensorizeTypeConverter() {
139 addConversion([](
Type type) {
return type; });
144 if (canBeDetensored(tensorType))
162 struct LinalgDetensorize
163 :
public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
164 using impl::LinalgDetensorizePassBase<
165 LinalgDetensorize>::LinalgDetensorizePassBase;
166 LinalgDetensorize() =
default;
170 virtual ~CostModel() =
default;
204 virtual void compute(FunctionOpInterface func,
205 DetensorizeTypeConverter typeConverter,
220 for (
auto blockArgumentElem : blockArgsToDetensor) {
221 Block *block = blockArgumentElem.getOwner();
224 pred != block->
pred_end(); ++pred) {
225 BranchOpInterface terminator =
226 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
228 terminator.getSuccessorOperands(pred.getSuccessorIndex());
230 if (blockOperands.empty() ||
231 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
234 detensorableBranchOps[terminator].insert(
235 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
239 return detensorableBranchOps;
254 class ControlFlowDetectionModel :
public CostModel {
256 void compute(FunctionOpInterface func,
257 DetensorizeTypeConverter typeConverter,
262 func->walk([&](cf::CondBranchOp condBr) {
263 llvm::append_range(workList, condBr.getOperands());
266 func->walk([&](cf::BranchOp br) {
267 llvm::append_range(workList, br.getOperands());
276 auto updateWorkListWithSuccessorArguments =
277 [&](
Value value, BranchOpInterface terminator) {
281 for (
auto operandIdx :
282 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
283 Value operand = terminator->getOperand(operandIdx);
285 if (operand == value) {
287 terminator.getSuccessorBlockArgument(operandIdx);
289 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
290 workList.push_back(*succBlockArg);
295 while (!workList.empty()) {
296 Value currentItem = workList.pop_back_val();
298 if (!visitedValues.insert(currentItem).second)
304 updateWorkListWithSuccessorArguments(
305 currentItem, dyn_cast<BranchOpInterface>(
312 for (
auto *user : currentItem.
getUsers())
313 llvm::append_range(workList, user->getResults());
321 if (dyn_cast<BlockArgument>(currentItem)) {
323 cast<BlockArgument>(currentItem);
332 blockArgsToDetensor.insert(currentItemBlockArgument);
335 pred != ownerBlock->
pred_end(); ++pred) {
336 BranchOpInterface predTerminator =
337 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
341 if (!predTerminator) {
342 opsToDetensor.clear();
343 blockArgsToDetensor.clear();
347 auto ownerBlockOperands =
348 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
350 if (ownerBlockOperands.empty() ||
351 ownerBlockOperands.isOperandProduced(
358 ownerBlockOperands[currentItemBlockArgument.
getArgNumber()]);
366 if (!visitedOps.insert(currentItemDefiningOp).second)
374 if (
auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
376 if (opsToDetensor.count(genericOp))
381 if (!shouldBeDetensored(genericOp, typeConverter)) {
385 opsToDetensor.insert(genericOp);
386 llvm::append_range(workList, genericOp.getInputs());
397 if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
404 [&](
Type resultType) { return resultType.isIntOrFloat(); }))
405 llvm::append_range(workList, currentItemDefiningOp->
getOperands());
414 for (
auto &blockArg : blockArgsToDetensor) {
415 Block *block = blockArg.getParentBlock();
420 pred != block->
pred_end(); ++pred) {
421 BranchOpInterface terminator =
422 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
424 terminator.getSuccessorOperands(pred.getSuccessorIndex());
426 if (blockOperands.empty() ||
427 blockOperands.isOperandProduced(blockArg.getArgNumber()))
431 blockOperands[blockArg.getArgNumber()].getDefiningOp();
435 if (isa_and_nonnull<GenericOp>(definingOp) &&
436 opsToDetensor.count(definingOp) == 0) {
437 blockArgsToRemove.insert(blockArg);
443 for (
auto &blockArg : blockArgsToRemove) {
444 blockArgsToDetensor.erase(blockArg);
450 class AggressiveDetensoringModel :
public CostModel {
452 void compute(FunctionOpInterface func,
453 DetensorizeTypeConverter typeConverter,
456 func->walk([&](GenericOp genericOp) {
457 if (shouldBeDetensored(genericOp, typeConverter))
458 opsToDetensor.insert(genericOp);
461 for (
Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
463 blockArgsToDetensor.insert(blockArgument);
467 void runOnOperation()
override {
469 DetensorizeTypeConverter typeConverter;
475 FunctionOpInterface funcOp = getOperation();
477 if (funcOp.getFunctionBody().empty())
485 Block *entryBlock = &funcOp.getFunctionBody().
front();
486 Block *postEntryBlock =
492 if (aggressiveMode.getValue()) {
493 AggressiveDetensoringModel costModel;
494 costModel.compute(funcOp, typeConverter, opsToDetensor,
495 blockArgsToDetensor);
497 ControlFlowDetectionModel costModel;
498 costModel.compute(funcOp, typeConverter, opsToDetensor,
499 blockArgsToDetensor);
502 detensorableBranchOps =
503 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
505 target.addDynamicallyLegalOp<GenericOp>(
506 [&](GenericOp op) {
return !opsToDetensor.count(op); });
508 target.markUnknownOpDynamicallyLegal([&](
Operation *op) {
514 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
515 Region &body = funcOp.getFunctionBody();
516 return llvm::all_of(llvm::drop_begin(body, 1), [&](
Block &block) {
517 return !llvm::any_of(
519 return blockArgument.
getOwner() == &block &&
520 !typeConverter.isLegal(blockArgument.
getType());
530 if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
531 if (!detensorableBranchOps.count(branchOp))
534 for (
auto operandIdx : detensorableBranchOps[branchOp])
535 if (!typeConverter.isLegal(
536 branchOp->getOperand(operandIdx).getType()))
545 patterns.add<DetensorizeGenericOp>(typeConverter, context);
546 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
547 blockArgsToDetensor);
551 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
552 int operandIdx) ->
bool {
553 return detensorableBranchOps.count(branchOp) &&
554 detensorableBranchOps[branchOp].count(operandIdx);
558 shouldConvertBranchOperand);
565 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
567 std::move(canonPatterns))))
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
This class describes a specific conversion target.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
bool isLegalForReturnOpTypeConversionPattern(Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...