MLIR  19.0.0git
Macros | Functions
LinalgTransformOps.cpp File Reference
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "linalg-transforms"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
#define DBGSNL()   (llvm::dbgs() << "\n")
 
#define LDBG(X)   LLVM_DEBUG(DBGS() << (X) << "\n")
 
#define DOWNSCALE(trans)
 
#define DOWNSCALE_CALL(a, b)   DownscaleSizeOneWindowed2DConvolution<a, b>
 
#define DOWNSCALE_NORMAL(a, b)   DOWNSCALE(DOWNSCALE_CALL(a, b))
 
#define GET_OP_CLASSES
 

Functions

template<typename PatternTy , typename... Args>
static FailureOr< LinalgOp > tryApply (Operation *operation, Args &&...args)
 Attempts to apply the pattern specified as template argument to the given operation. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
 Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle)
 
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
 When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred. More...
 
template<typename Range >
static LogicalResult applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
 Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More...
 
static Operation *replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
 Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More...
 
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 Find the first "extract" user of producerOp and tile it right before its use. More...
 
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More...
 
static Operation *cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 
static void printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
 
static ParseResult parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
 
template<typename RelayoutOpTy >
bool isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
 Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More...
 
template<typename OpTy >
DiagnosedSilenceableFailure doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

Definition at line 54 of file LinalgTransformOps.cpp.

◆ DBGSNL

#define DBGSNL ( )    (llvm::dbgs() << "\n")

Definition at line 55 of file LinalgTransformOps.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "linalg-transforms"

Definition at line 53 of file LinalgTransformOps.cpp.

◆ DOWNSCALE

#define DOWNSCALE (   trans)
Value:
{ \
FailureOr<LinalgOp> res = tryApply<trans>(target); \
if (succeeded(res)) { \
results.push_back(*res); \
} \
}
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56

◆ DOWNSCALE_CALL

#define DOWNSCALE_CALL(a, b)   DownscaleSizeOneWindowed2DConvolution<a, b>

◆ DOWNSCALE_NORMAL

#define DOWNSCALE_NORMAL(a, b)   DOWNSCALE(DOWNSCALE_CALL(a, b))

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 3485 of file LinalgTransformOps.cpp.

◆ LDBG

#define LDBG (   X)    LLVM_DEBUG(DBGS() << (X) << "\n")

Definition at line 56 of file LinalgTransformOps.cpp.

Function Documentation

◆ applyTilingToAll()

template<typename Range >
static LogicalResult applyTilingToAll ( RewriterBase &  rewriter,
Operation *  transformOp,
Range &&  payloadOps,
unsigned  numLoops,
transform::TransformResults &  transformResults,
function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)>  applyFn 
)
static

Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.

Definition at line 497 of file LinalgTransformOps.cpp.

References mlir::Operation::emitError(), mlir::RewriterBase::eraseOp(), mlir::failed(), mlir::failure(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceAllUsesWith(), mlir::transform::TransformResults::set(), mlir::OpBuilder::setInsertionPoint(), and mlir::success().

◆ cloneAndFuseFirstUse()

static Operation* cloneAndFuseFirstUse ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

◆ doit()

template<typename OpTy >
DiagnosedSilenceableFailure doit ( RewriterBase &  rewriter,
OpTy  target,
transform::ApplyToEachResultList &  results,
transform::TransformState &  state 
)

◆ isValidPackingPermutation()

template<typename RelayoutOpTy >
bool isValidPackingPermutation ( RelayoutOpTy  op,
ArrayRef< int64_t >  permutation,
OuterOrInnerPerm  outerOrInnerPerm = OuterOrInnerPerm::Outer 
)

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op.

This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty, to allow a simpler polymorphic implementation.

Definition at line 1591 of file LinalgTransformOps.cpp.

◆ parseMultitileSizesTypes()

static ParseResult parseMultitileSizesTypes ( OpAsmParser &  parser,
Type &  targetType,
Type &  lowSizeType,
Type &  highSizeType,
Type &  splitPointType 
)
static

◆ printMultitileSizesTypes()

static void printMultitileSizesTypes ( OpAsmPrinter &  printer,
Operation *  op,
Type  targetType,
Type  lowSizeType,
Type  ,
Type   
)
static

Definition at line 1311 of file LinalgTransformOps.cpp.

◆ reifyMixedParamAndHandleResults()

static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults ( TransformState &  state,
TransformOpInterface &  transformOp,
ArrayRef< OpFoldResult >  mixedResults,
SmallVectorImpl< int64_t > &  reified 
)
static

When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred.

If a result is a Value, then it must be either a ParamType or a handle to a constant-like op.

Definition at line 178 of file LinalgTransformOps.cpp.

◆ replaceForAllWithNewSignature()

static Operation* replaceForAllWithNewSignature ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp,
TilingResult &  tileAndFuseResult,
int64_t  resultNumber,
SmallVector< OpFoldResult > &  offsets,
SmallVector< OpFoldResult > &  sizes 
)
static

◆ tileAndFuseFirstExtractUse()

static std::tuple<SmallVector<Operation *>, Operation *> tileAndFuseFirstExtractUse ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

Find the first "extract" user of producerOp and tile it right before its use.

The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails. If the tiled op has uses that are dominated by containingOp, return a new containingOp with results of the fused op appended to results of the containingOp, or nullptr if there are no dominated uses.

Definition at line 689 of file LinalgTransformOps.cpp.

References DBGS, diag(), mlir::failed(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPoint().

◆ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument()

static SmallVector<Operation *> tileAndFuseFirstExtractUseThroughContainingOpBlockArgument ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail.

Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails.

Definition at line 771 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::failed(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), mlir::Operation::setOperand(), and mlir::succeeded().

◆ tryApply()

template<typename PatternTy , typename... Args>
static FailureOr<LinalgOp> tryApply ( Operation *  operation,
Args &&...  args 
)
static

Attempts to apply the pattern specified as template argument to the given operation.

The pattern is expected to have a returningMatchAndRewrite function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.

Definition at line 64 of file LinalgTransformOps.cpp.

◆ unpackSingleIndexResultPayloadOperations() [1/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState &  state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
ArrayRef< OpFoldResult >  ofrs 
)
static

Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.

Definition at line 92 of file LinalgTransformOps.cpp.

◆ unpackSingleIndexResultPayloadOperations() [2/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState &  state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
Value  packedHandle 
)
static

Definition at line 144 of file LinalgTransformOps.cpp.