21 #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
22 #include "mlir/Dialect/Linalg/Passes.h.inc"
30 assert(inputs.size() == 1);
31 auto inputType = inputs[0].
getType();
32 if (isa<TensorType>(inputType))
37 return builder.
create<tensor::FromElementsOp>(
49 return tensorType.
hasRank() && tensorType.getRank() == 0;
53 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
55 llvm::all_of(genericOp->getOpOperands(), [&](
OpOperand &opOperand) {
56 return !typeConverter.isLegal(opOperand.get().getType());
65 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
67 Block *originalBlock = op->getBlock();
70 Block *opEntryBlock = &*op.getRegion().
begin();
71 YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
80 rewriter.
replaceOp(op, yieldOp->getOperands());
83 rewriter.
mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
94 struct FunctionNonEntryBlockConversion
99 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
105 Region &region = op.getFunctionBody();
108 llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
110 block.getNumArguments());
113 int idx = blockArgument.getArgNumber();
115 if (blockArgsToDetensor.count(blockArgument))
116 conversion.addInputs(idx, {getTypeConverter()->convertType(
117 block.getArgumentTypes()[idx])});
119 conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
135 DetensorizeTypeConverter() {
136 addConversion([](
Type type) {
return type; });
141 if (canBeDetensored(tensorType))
158 struct LinalgDetensorize
159 :
public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
160 using impl::LinalgDetensorizePassBase<
161 LinalgDetensorize>::LinalgDetensorizePassBase;
162 LinalgDetensorize() =
default;
166 virtual ~CostModel() =
default;
200 virtual void compute(FunctionOpInterface func,
201 DetensorizeTypeConverter typeConverter,
216 for (
auto blockArgumentElem : blockArgsToDetensor) {
217 Block *block = blockArgumentElem.getOwner();
220 pred != block->
pred_end(); ++pred) {
221 BranchOpInterface terminator =
222 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
224 terminator.getSuccessorOperands(pred.getSuccessorIndex());
226 if (blockOperands.empty() ||
227 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
230 detensorableBranchOps[terminator].insert(
231 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
235 return detensorableBranchOps;
250 class ControlFlowDetectionModel :
public CostModel {
252 void compute(FunctionOpInterface func,
253 DetensorizeTypeConverter typeConverter,
258 func->walk([&](cf::CondBranchOp condBr) {
259 llvm::append_range(workList, condBr.getOperands());
262 func->walk([&](cf::BranchOp br) {
263 llvm::append_range(workList, br.getOperands());
272 auto updateWorkListWithSuccessorArguments =
273 [&](
Value value, BranchOpInterface terminator) {
277 for (
auto operandIdx :
278 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
279 Value operand = terminator->getOperand(operandIdx);
281 if (operand == value) {
283 terminator.getSuccessorBlockArgument(operandIdx);
285 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
286 workList.push_back(*succBlockArg);
291 while (!workList.empty()) {
292 Value currentItem = workList.pop_back_val();
294 if (!visitedValues.insert(currentItem).second)
300 updateWorkListWithSuccessorArguments(
301 currentItem, dyn_cast<BranchOpInterface>(
308 for (
auto *user : currentItem.
getUsers())
309 llvm::append_range(workList, user->getResults());
317 if (
auto currentItemBlockArgument =
318 dyn_cast<BlockArgument>(currentItem)) {
319 Block *ownerBlock = currentItemBlockArgument.getOwner();
327 blockArgsToDetensor.insert(currentItemBlockArgument);
330 pred != ownerBlock->
pred_end(); ++pred) {
331 BranchOpInterface predTerminator =
332 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
336 if (!predTerminator) {
337 opsToDetensor.clear();
338 blockArgsToDetensor.clear();
342 auto ownerBlockOperands =
343 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
345 if (ownerBlockOperands.empty() ||
346 ownerBlockOperands.isOperandProduced(
347 currentItemBlockArgument.getArgNumber()))
353 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
361 if (!visitedOps.insert(currentItemDefiningOp).second)
369 if (
auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
371 if (opsToDetensor.count(genericOp))
376 if (!shouldBeDetensored(genericOp, typeConverter)) {
380 opsToDetensor.insert(genericOp);
381 llvm::append_range(workList, genericOp.getInputs());
392 if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
399 [&](
Type resultType) { return resultType.isIntOrFloat(); }))
400 llvm::append_range(workList, currentItemDefiningOp->
getOperands());
409 for (
auto &blockArg : blockArgsToDetensor) {
410 Block *block = blockArg.getParentBlock();
415 pred != block->
pred_end(); ++pred) {
416 BranchOpInterface terminator =
417 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
419 terminator.getSuccessorOperands(pred.getSuccessorIndex());
421 if (blockOperands.empty() ||
422 blockOperands.isOperandProduced(blockArg.getArgNumber()))
426 blockOperands[blockArg.getArgNumber()].getDefiningOp();
430 if (isa_and_nonnull<GenericOp>(definingOp) &&
431 opsToDetensor.count(definingOp) == 0) {
432 blockArgsToRemove.insert(blockArg);
438 for (
auto &blockArg : blockArgsToRemove) {
439 blockArgsToDetensor.erase(blockArg);
445 class AggressiveDetensoringModel :
public CostModel {
447 void compute(FunctionOpInterface func,
448 DetensorizeTypeConverter typeConverter,
451 func->walk([&](GenericOp genericOp) {
452 if (shouldBeDetensored(genericOp, typeConverter))
453 opsToDetensor.insert(genericOp);
456 for (
Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
461 void runOnOperation()
override {
463 DetensorizeTypeConverter typeConverter;
469 FunctionOpInterface funcOp = getOperation();
471 if (funcOp.getFunctionBody().empty())
479 Block *entryBlock = &funcOp.getFunctionBody().
front();
480 Block *postEntryBlock =
486 if (aggressiveMode.getValue()) {
487 AggressiveDetensoringModel costModel;
488 costModel.compute(funcOp, typeConverter, opsToDetensor,
489 blockArgsToDetensor);
491 ControlFlowDetectionModel costModel;
492 costModel.compute(funcOp, typeConverter, opsToDetensor,
493 blockArgsToDetensor);
496 detensorableBranchOps =
497 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
499 target.addDynamicallyLegalOp<GenericOp>(
500 [&](GenericOp op) {
return !opsToDetensor.count(op); });
502 target.markUnknownOpDynamicallyLegal([&](
Operation *op) {
508 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
509 Region &body = funcOp.getFunctionBody();
510 return llvm::all_of(llvm::drop_begin(body, 1), [&](
Block &block) {
511 return !llvm::any_of(
513 return blockArgument.
getOwner() == &block &&
514 !typeConverter.isLegal(blockArgument.
getType());
524 if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
525 if (!detensorableBranchOps.count(branchOp))
528 for (
auto operandIdx : detensorableBranchOps[branchOp])
529 if (!typeConverter.isLegal(
530 branchOp->getOperand(operandIdx).getType()))
539 patterns.add<DetensorizeGenericOp>(typeConverter, context);
540 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
541 blockArgsToDetensor);
545 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
546 int operandIdx) ->
bool {
547 return detensorableBranchOps.count(branchOp) &&
548 detensorableBranchOps[branchOp].count(operandIdx);
552 shouldConvertBranchOperand);
559 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
This class describes a specific conversion target.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
bool isLegalForReturnOpTypeConversionPattern(Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
const FrozenRewritePatternSet & patterns
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...