#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
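// Source materialization for detensoring: rebuild a 0-D tensor from the scalar
// that replaced it, using tensor.from_elements. If the value is already a
// tensor, no materialization is needed.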
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  auto inputType = inputs[0].getType();
  if (isa<TensorType>(inputType))
    return nullptr;

  return tensor::FromElementsOp::create(
      builder, loc, RankedTensorType::get({}, inputType), inputs[0]);
}
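// Only ranked 0-D tensors can be detensored: they hold exactly one element.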
static bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}
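// A linalg.generic should be detensored when every operand has a type the
// detensorize type converter considers illegal, i.e. a detensorable tensor
// type.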
static bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
           return !typeConverter.isLegal(opOperand.get().getType());
         });
}
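// Rewrites a detensorable linalg.generic by inlining its body into the
// enclosing block: the op's results are replaced by the values its region
// yields, so the scalar computation runs directly on the detensored operands.
// As an illustration, a 0-D addition expressed as
//   %2 = linalg.generic ... ins(%0, %1 : tensor<i32>, tensor<i32>) ...
// ends up as the plain arithmetic op from the generic's body operating on the
// extracted scalars.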
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather information about the op's region before inlining it.
    Block *opEntryBlock = &*op.getRegion().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());

    // Split the block at the op and inline the op's region at the split point.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.getRegion(), newBlock);

    // Replace the op with the values yielded from its region.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // Merge the op's entry block into the original block, remapping its
    // arguments to the adaptor operands, then fold the split block back in.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});
    return success();
  }
};
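// Converts the signatures of a function's non-entry blocks, detensoring
// exactly the block arguments collected by the cost model and leaving the
// remaining arguments untouched.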
struct FunctionNonEntryBlockConversion
    : public OpInterfaceConversionPattern<FunctionOpInterface> {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : OpInterfaceConversionPattern(converter, ctx),
        blockArgsToDetensor(std::move(blockArgsToDetensor)) {}

  LogicalResult
  matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    Region &region = op.getFunctionBody();

    // Convert the signatures of all non-entry blocks.
    for (Block &block :
         llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
      TypeConverter::SignatureConversion conversion(
          block.getNumArguments());

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          conversion.addInputs(idx, {getTypeConverter()->convertType(
                                        block.getArgumentTypes()[idx])});
        else
          conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
      }

      rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
    }

    rewriter.finalizeOpModification(op);
    return success();
  }

  const DenseSet<BlockArgument> blockArgsToDetensor;
};
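// Type converter that maps a detensorable (0-D) tensor type to its element
// type and leaves every other type unchanged, plus the materializations that
// extract the scalar from, or rebuild, such tensors.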
class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();
      return tensorType;
    });
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return tensor::ExtractOp::create(builder, loc, inputs[0], ValueRange{});
    });
    addSourceMaterialization(sourceMaterializationCallback);
  }
};
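// The detensorize pass. A pluggable cost model (selected via the pass's
// aggressive-mode option) decides which linalg.generic ops and block arguments
// to detensor; the dialect conversion in runOnOperation then performs the
// rewrite.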
struct LinalgDetensorize
    : public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
  using impl::LinalgDetensorizePassBase<
      LinalgDetensorize>::LinalgDetensorizePassBase;
  LinalgDetensorize() = default;

  class CostModel {
  public:
    virtual ~CostModel() = default;
    virtual void compute(FunctionOpInterface func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;
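    // For every detensorable block argument, record the operand indices of the
    // predecessor branch ops that forward a value into it; those branch
    // operands must be rewritten together with the argument.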
    static DenseMap<Operation *, DenseSet<int>>
    computeBranchOpDetensoring(const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          // Skip arguments whose value is produced internally by the
          // terminator rather than forwarded from an operand.
          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
        }
      }

      return detensorableBranchOps;
    }
  };
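  // Cost model that detensors only values involved in control flow: starting
  // from the operands of cf.br/cf.cond_br ops, it walks the use-def graph
  // forwards and backwards to find the linalg.generic ops and block arguments
  // that carry such values.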
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      // Seed the worklist with the operands of all cf.cond_br and cf.br ops.
      func->walk([&](cf::CondBranchOp condBr) {
        llvm::append_range(workList, condBr.getOperands());
      });
      func->walk([&](cf::BranchOp br) {
        llvm::append_range(workList, br.getOperands());
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;
      // If `value` is forwarded to a successor block by `terminator`, add the
      // corresponding successor block argument to the worklist.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };
      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // Look forward: if currentItem escapes to a successor block through
        // its parent block's terminator, add the successor argument.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // Look forward: traverse into the results of currentItem's users.
        for (auto *user : currentItem.getUsers())
          llvm::append_range(workList, user->getResults());
        // Look backward: currentItem is a block argument involved in control
        // flow and should be detensored.
        if (auto currentItemBlockArgument =
                dyn_cast<BlockArgument>(currentItem)) {
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Entry block arguments are part of the function signature and are
          // never detensored.
          if (ownerBlock->isEntryBlock())
            continue;

          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // A predecessor that is not a branch op (e.g. a caller of the
            // function) is not supported; give up on detensoring entirely.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (ownerBlockOperands.empty() ||
                ownerBlockOperands.isOperandProduced(
                    currentItemBlockArgument.getArgNumber()))
              continue;

            // Add the value each predecessor forwards into this argument.
            workList.push_back(
                ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }
        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // A linalg.generic that should be detensored: record it and keep
        // traversing through its inputs.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          if (opsToDetensor.count(genericOp))
            continue;

          if (!shouldBeDetensored(genericOp, typeConverter))
            continue;

          opsToDetensor.insert(genericOp);
          llvm::append_range(workList, genericOp.getInputs());
          continue;
        }
        // tensor.from_elements ops end the backward traversal.
        if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;
        // Ops producing only scalar int/float results keep the traversal going.
        if (llvm::all_of(currentItemDefiningOp->getResultTypes(),
                         [&](Type resultType) { return resultType.isIntOrFloat(); }))
          llvm::append_range(workList, currentItemDefiningOp->getOperands());
      }
      // A block argument is dropped from the set again if any predecessor
      // forwards into it a value produced by a linalg.generic that will not
      // be detensored.
      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (blockOperands.empty() ||
              blockOperands.isOperandProduced(blockArg.getArgNumber()))
            continue;

          Operation *definingOp =
              blockOperands[blockArg.getArgNumber()].getDefiningOp();

          if (isa_and_nonnull<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };
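  // Cost model that detensors every detensorable linalg.generic and every
  // non-entry block argument in the function.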
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(FunctionOpInterface func,
                 DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      for (Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
        for (BlockArgument blockArgument : block.getArguments())
          blockArgsToDetensor.insert(blockArgument);
    }
  };
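  // Pass driver: run the selected cost model, mark the chosen ops and block
  // arguments, then apply a full dialect conversion followed by a small
  // canonicalization that cleans up leftover tensor.from_elements ops.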
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;
    FunctionOpInterface funcOp = getOperation();

    if (funcOp.getFunctionBody().empty())
      return;

    // Temporarily split off the entry block so that the dialect conversion
    // below never needs to rewrite the function signature.
    IRRewriter rewriter(funcOp->getContext());
    Block *entryBlock = &funcOp.getFunctionBody().front();
    Block *postEntryBlock =
        rewriter.splitBlock(entryBlock, entryBlock->begin());
    rewriter.setInsertionPointToStart(entryBlock);
    auto branch = cf::BranchOp::create(rewriter, rewriter.getUnknownLoc(),
                                       postEntryBlock);
    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(funcOp, typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if none of its non-entry blocks still owns a block
      // argument that was marked for detensoring but kept an illegal type.
      if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
        Region &body = funcOp.getFunctionBody();
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          return !llvm::any_of(
              blockArgsToDetensor, [&](BlockArgument blockArgument) {
                return blockArgument.getOwner() == &block &&
                       !typeConverter.isLegal(blockArgument.getType());
              });
        });
      }
      // A detensorable branch op is legal once its detensorable operands are.
      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;
        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;
        return true;
      }
      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
             isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                     /*returnOpAlwaysLegal=*/true);
    });
    patterns.add<DetensorizeGenericOp>(typeConverter, context);
    patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                  blockArgsToDetensor);

    // Since non-entry block arguments get detensored, the branch ops that feed
    // them must be rewritten to pass the detensored values as well.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };
    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    // Clean up tensor.from_elements/tensor.extract pairs left behind by the
    // materializations.
    RewritePatternSet canonPatterns(context);
    tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
    if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
      signalPassFailure();

    // Stitch the temporarily split entry block back together.
    rewriter.eraseOp(branch);
    rewriter.mergeBlocks(postEntryBlock, entryBlock);
  }
};