24 #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
33 assert(inputs.size() == 1);
34 auto inputType = inputs[0].
getType();
35 if (isa<TensorType>(inputType))
40 return builder.
create<tensor::FromElementsOp>(
52 return tensorType.
hasRank() && tensorType.getRank() == 0;
56 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
58 llvm::all_of(genericOp->getOpOperands(), [&](
OpOperand &opOperand) {
59 return !typeConverter.isLegal(opOperand.get().getType());
68 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
70 Block *originalBlock = op->getBlock();
73 Block *opEntryBlock = &*op.getRegion().
begin();
74 YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
83 rewriter.
replaceOp(op, yieldOp->getOperands());
86 rewriter.
mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
97 struct FunctionNonEntryBlockConversion
102 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
108 Region ®ion = op.getFunctionBody();
111 llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
113 block.getNumArguments());
116 int idx = blockArgument.getArgNumber();
118 if (blockArgsToDetensor.count(blockArgument))
119 conversion.addInputs(idx, {getTypeConverter()->convertType(
120 block.getArgumentTypes()[idx])});
122 conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
138 DetensorizeTypeConverter() {
139 addConversion([](
Type type) {
return type; });
144 if (canBeDetensored(tensorType))
161 struct LinalgDetensorize
162 :
public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
163 using impl::LinalgDetensorizePassBase<
164 LinalgDetensorize>::LinalgDetensorizePassBase;
165 LinalgDetensorize() =
default;
169 virtual ~CostModel() =
default;
203 virtual void compute(FunctionOpInterface func,
204 DetensorizeTypeConverter typeConverter,
219 for (
auto blockArgumentElem : blockArgsToDetensor) {
220 Block *block = blockArgumentElem.getOwner();
223 pred != block->
pred_end(); ++pred) {
224 BranchOpInterface terminator =
225 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
227 terminator.getSuccessorOperands(pred.getSuccessorIndex());
229 if (blockOperands.empty() ||
230 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
233 detensorableBranchOps[terminator].insert(
234 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
238 return detensorableBranchOps;
253 class ControlFlowDetectionModel :
public CostModel {
255 void compute(FunctionOpInterface func,
256 DetensorizeTypeConverter typeConverter,
261 func->walk([&](cf::CondBranchOp condBr) {
262 llvm::append_range(workList, condBr.getOperands());
265 func->walk([&](cf::BranchOp br) {
266 llvm::append_range(workList, br.getOperands());
275 auto updateWorkListWithSuccessorArguments =
276 [&](
Value value, BranchOpInterface terminator) {
280 for (
auto operandIdx :
281 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
282 Value operand = terminator->getOperand(operandIdx);
284 if (operand == value) {
286 terminator.getSuccessorBlockArgument(operandIdx);
288 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
289 workList.push_back(*succBlockArg);
294 while (!workList.empty()) {
295 Value currentItem = workList.pop_back_val();
297 if (!visitedValues.insert(currentItem).second)
303 updateWorkListWithSuccessorArguments(
304 currentItem, dyn_cast<BranchOpInterface>(
311 for (
auto *user : currentItem.
getUsers())
312 llvm::append_range(workList, user->getResults());
320 if (
auto currentItemBlockArgument =
321 dyn_cast<BlockArgument>(currentItem)) {
322 Block *ownerBlock = currentItemBlockArgument.getOwner();
330 blockArgsToDetensor.insert(currentItemBlockArgument);
333 pred != ownerBlock->
pred_end(); ++pred) {
334 BranchOpInterface predTerminator =
335 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
339 if (!predTerminator) {
340 opsToDetensor.clear();
341 blockArgsToDetensor.clear();
345 auto ownerBlockOperands =
346 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
348 if (ownerBlockOperands.empty() ||
349 ownerBlockOperands.isOperandProduced(
350 currentItemBlockArgument.getArgNumber()))
356 ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
364 if (!visitedOps.insert(currentItemDefiningOp).second)
372 if (
auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
374 if (opsToDetensor.count(genericOp))
379 if (!shouldBeDetensored(genericOp, typeConverter)) {
383 opsToDetensor.insert(genericOp);
384 llvm::append_range(workList, genericOp.getInputs());
395 if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
402 [&](
Type resultType) { return resultType.isIntOrFloat(); }))
403 llvm::append_range(workList, currentItemDefiningOp->
getOperands());
412 for (
auto &blockArg : blockArgsToDetensor) {
413 Block *block = blockArg.getParentBlock();
418 pred != block->
pred_end(); ++pred) {
419 BranchOpInterface terminator =
420 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
422 terminator.getSuccessorOperands(pred.getSuccessorIndex());
424 if (blockOperands.empty() ||
425 blockOperands.isOperandProduced(blockArg.getArgNumber()))
429 blockOperands[blockArg.getArgNumber()].getDefiningOp();
433 if (isa_and_nonnull<GenericOp>(definingOp) &&
434 opsToDetensor.count(definingOp) == 0) {
435 blockArgsToRemove.insert(blockArg);
441 for (
auto &blockArg : blockArgsToRemove) {
442 blockArgsToDetensor.erase(blockArg);
448 class AggressiveDetensoringModel :
public CostModel {
450 void compute(FunctionOpInterface func,
451 DetensorizeTypeConverter typeConverter,
454 func->walk([&](GenericOp genericOp) {
455 if (shouldBeDetensored(genericOp, typeConverter))
456 opsToDetensor.insert(genericOp);
459 for (
Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
464 void runOnOperation()
override {
466 DetensorizeTypeConverter typeConverter;
472 FunctionOpInterface funcOp = getOperation();
474 if (funcOp.getFunctionBody().empty())
482 Block *entryBlock = &funcOp.getFunctionBody().
front();
483 Block *postEntryBlock =
489 if (aggressiveMode.getValue()) {
490 AggressiveDetensoringModel costModel;
491 costModel.compute(funcOp, typeConverter, opsToDetensor,
492 blockArgsToDetensor);
494 ControlFlowDetectionModel costModel;
495 costModel.compute(funcOp, typeConverter, opsToDetensor,
496 blockArgsToDetensor);
499 detensorableBranchOps =
500 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
502 target.addDynamicallyLegalOp<GenericOp>(
503 [&](GenericOp op) {
return !opsToDetensor.count(op); });
505 target.markUnknownOpDynamicallyLegal([&](
Operation *op) {
511 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
512 Region &body = funcOp.getFunctionBody();
513 return llvm::all_of(llvm::drop_begin(body, 1), [&](
Block &block) {
514 return !llvm::any_of(
516 return blockArgument.
getOwner() == &block &&
517 !typeConverter.isLegal(blockArgument.
getType());
527 if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
528 if (!detensorableBranchOps.count(branchOp))
531 for (
auto operandIdx : detensorableBranchOps[branchOp])
532 if (!typeConverter.isLegal(
533 branchOp->getOperand(operandIdx).getType()))
542 patterns.add<DetensorizeGenericOp>(typeConverter, context);
543 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
544 blockArgsToDetensor);
548 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
549 int operandIdx) ->
bool {
550 return detensorableBranchOps.count(branchOp) &&
551 detensorableBranchOps[branchOp].count(operandIdx);
555 shouldConvertBranchOperand);
562 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Block * applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter=nullptr)
Apply a signature conversion to given block.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
This class describes a specific conversion target.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
bool isLegalForReturnOpTypeConversionPattern(Operation *op, const TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
const FrozenRewritePatternSet & patterns
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...