24 #define GEN_PASS_DEF_LINALGDETENSORIZEPASS
25 #include "mlir/Dialect/Linalg/Passes.h.inc"
33 assert(inputs.size() == 1);
34 auto inputType = inputs[0].
getType();
35 if (isa<TensorType>(inputType))
40 return builder.
create<tensor::FromElementsOp>(
52 return tensorType.
hasRank() && tensorType.getRank() == 0;
56 GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
58 llvm::all_of(genericOp->getOpOperands(), [&](
OpOperand &opOperand) {
59 return !typeConverter.isLegal(opOperand.get().getType());
68 matchAndRewrite(GenericOp op, OpAdaptor adaptor,
83 rewriter.
replaceOp(op, yieldOp->getOperands());
86 rewriter.
mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
97 struct FunctionNonEntryBlockConversion
102 blockArgsToDetensor(std::move(blockArgsToDetensor)) {}
108 Region &region = op.getFunctionBody();
111 for (
Block &block : llvm::drop_begin(region, 1)) {
112 conversions.emplace_back(block.getNumArguments());
116 int idx = blockArgument.getArgNumber();
118 if (blockArgsToDetensor.count(blockArgument))
119 back.
addInputs(idx, {getTypeConverter()->convertType(
120 block.getArgumentTypes()[idx])});
122 back.
addInputs(idx, {block.getArgumentTypes()[idx]});
142 DetensorizeTypeConverter() {
143 addConversion([](
Type type) {
return type; });
148 if (canBeDetensored(tensorType))
166 struct LinalgDetensorize
167 :
public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
168 using impl::LinalgDetensorizePassBase<
169 LinalgDetensorize>::LinalgDetensorizePassBase;
170 LinalgDetensorize() =
default;
174 virtual ~CostModel() =
default;
208 virtual void compute(FunctionOpInterface func,
209 DetensorizeTypeConverter typeConverter,
224 for (
auto blockArgumentElem : blockArgsToDetensor) {
225 Block *block = blockArgumentElem.getOwner();
228 pred != block->
pred_end(); ++pred) {
229 BranchOpInterface terminator =
230 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
232 terminator.getSuccessorOperands(pred.getSuccessorIndex());
234 if (blockOperands.empty() ||
235 blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
238 detensorableBranchOps[terminator].insert(
239 blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
243 return detensorableBranchOps;
258 class ControlFlowDetectionModel :
public CostModel {
260 void compute(FunctionOpInterface func,
261 DetensorizeTypeConverter typeConverter,
266 func->walk([&](cf::CondBranchOp condBr) {
267 llvm::append_range(workList, condBr.getOperands());
270 func->walk([&](cf::BranchOp br) {
271 llvm::append_range(workList, br.getOperands());
280 auto updateWorkListWithSuccessorArguments =
281 [&](
Value value, BranchOpInterface terminator) {
285 for (
auto operandIdx :
286 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
287 Value operand = terminator->getOperand(operandIdx);
289 if (operand == value) {
291 terminator.getSuccessorBlockArgument(operandIdx);
293 if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
294 workList.push_back(*succBlockArg);
299 while (!workList.empty()) {
300 Value currentItem = workList.pop_back_val();
302 if (!visitedValues.insert(currentItem).second)
308 updateWorkListWithSuccessorArguments(
309 currentItem, dyn_cast<BranchOpInterface>(
316 for (
auto *user : currentItem.
getUsers())
317 llvm::append_range(workList, user->getResults());
325 if (dyn_cast<BlockArgument>(currentItem)) {
327 cast<BlockArgument>(currentItem);
336 blockArgsToDetensor.insert(currentItemBlockArgument);
339 pred != ownerBlock->
pred_end(); ++pred) {
340 BranchOpInterface predTerminator =
341 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
345 if (!predTerminator) {
346 opsToDetensor.clear();
347 blockArgsToDetensor.clear();
351 auto ownerBlockOperands =
352 predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
354 if (ownerBlockOperands.empty() ||
355 ownerBlockOperands.isOperandProduced(
362 ownerBlockOperands[currentItemBlockArgument.
getArgNumber()]);
370 if (!visitedOps.insert(currentItemDefiningOp).second)
378 if (
auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
380 if (opsToDetensor.count(genericOp))
385 if (!shouldBeDetensored(genericOp, typeConverter)) {
389 opsToDetensor.insert(genericOp);
390 llvm::append_range(workList, genericOp.getInputs());
401 if (isa<tensor::FromElementsOp>(currentItemDefiningOp))
408 [&](
Type resultType) { return resultType.isIntOrFloat(); }))
409 llvm::append_range(workList, currentItemDefiningOp->
getOperands());
418 for (
auto &blockArg : blockArgsToDetensor) {
419 Block *block = blockArg.getParentBlock();
424 pred != block->
pred_end(); ++pred) {
425 BranchOpInterface terminator =
426 dyn_cast<BranchOpInterface>((*pred)->getTerminator());
428 terminator.getSuccessorOperands(pred.getSuccessorIndex());
430 if (blockOperands.empty() ||
431 blockOperands.isOperandProduced(blockArg.getArgNumber()))
435 blockOperands[blockArg.getArgNumber()].getDefiningOp();
439 if (isa_and_nonnull<GenericOp>(definingOp) &&
440 opsToDetensor.count(definingOp) == 0) {
441 blockArgsToRemove.insert(blockArg);
447 for (
auto &blockArg : blockArgsToRemove) {
448 blockArgsToDetensor.erase(blockArg);
454 class AggressiveDetensoringModel :
public CostModel {
456 void compute(FunctionOpInterface func,
457 DetensorizeTypeConverter typeConverter,
460 func->walk([&](GenericOp genericOp) {
461 if (shouldBeDetensored(genericOp, typeConverter))
462 opsToDetensor.insert(genericOp);
465 for (
Block &block : llvm::drop_begin(func.getFunctionBody(), 1))
467 blockArgsToDetensor.insert(blockArgument);
471 void runOnOperation()
override {
473 DetensorizeTypeConverter typeConverter;
479 FunctionOpInterface funcOp = getOperation();
481 if (funcOp.getFunctionBody().empty())
489 Block *entryBlock = &funcOp.getFunctionBody().
front();
490 Block *postEntryBlock =
496 if (aggressiveMode.getValue()) {
497 AggressiveDetensoringModel costModel;
498 costModel.compute(funcOp, typeConverter, opsToDetensor,
499 blockArgsToDetensor);
501 ControlFlowDetectionModel costModel;
502 costModel.compute(funcOp, typeConverter, opsToDetensor,
503 blockArgsToDetensor);
506 detensorableBranchOps =
507 CostModel::computeBranchOpDetensoring(blockArgsToDetensor);
509 target.addDynamicallyLegalOp<GenericOp>(
510 [&](GenericOp op) {
return !opsToDetensor.count(op); });
512 target.markUnknownOpDynamicallyLegal([&](
Operation *op) {
518 if (
auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
519 Region &body = funcOp.getFunctionBody();
520 return llvm::all_of(llvm::drop_begin(body, 1), [&](
Block &block) {
521 return !llvm::any_of(
523 return blockArgument.
getOwner() == &block &&
524 !typeConverter.isLegal(blockArgument.
getType());
534 if (
auto branchOp = dyn_cast<BranchOpInterface>(op)) {
535 if (!detensorableBranchOps.count(branchOp))
538 for (
auto operandIdx : detensorableBranchOps[branchOp])
539 if (!typeConverter.isLegal(
540 branchOp->getOperand(operandIdx).getType()))
549 patterns.add<DetensorizeGenericOp>(typeConverter, context);
550 patterns.add<FunctionNonEntryBlockConversion>(context, typeConverter,
551 blockArgsToDetensor);
555 auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
556 int operandIdx) ->
bool {
557 return detensorableBranchOps.count(branchOp) &&
558 detensorableBranchOps[branchOp].count(operandIdx);
562 shouldConvertBranchOperand);
569 tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
571 std::move(canonPatterns))))
static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc)
static MLIRContext * getContext(OpFoldResult val)
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
OpListType::iterator iterator
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
pred_iterator pred_begin()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void startOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
void cancelOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
void finalizeOpModification(Operation *op) override
PatternRewriter hook for updating the given operation in-place.
LogicalResult convertNonEntryRegionTypes(Region *region, const TypeConverter &converter, ArrayRef< TypeConverter::SignatureConversion > blockConversions)
Convert the types of block arguments within the given region except for the entry region.
This class describes a specific conversion target.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
OpInterfaceConversionPattern is a wrapper around ConversionPattern that allows for matching and rewri...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
Implement a predecessor iterator for blocks.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter, function_ref< bool(BranchOpInterface branchOp, int idx)> shouldConvertBranchOperand=nullptr)
Add a pattern to the given pattern list to rewrite branch operations to use operands that have been l...
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isNotBranchOpInterfaceOrReturnLikeOp(Operation *op)
Return true if op is neither BranchOpInterface nor ReturnLike.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isLegalForReturnOpTypeConversionPattern(Operation *op, TypeConverter &converter, bool returnOpAlwaysLegal=false)
For ReturnLike ops (except return), return True.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.