MLIR  20.0.0git
Macros | Functions
LinalgTransformOps.cpp File Reference
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"

Go to the source code of this file.

Macros

#define DEBUG_TYPE   "linalg-transforms"
 
#define DBGS()   (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
#define DBGSNL()   (llvm::dbgs() << "\n")
 
#define LDBG(X)   LLVM_DEBUG(DBGS() << (X) << "\n")
 
#define DOWNSCALE(trans)
 
#define DOWNSCALE_CALL(a, b)   DownscaleSizeOneWindowed2DConvolution<a, b>
 
#define DOWNSCALE_NORMAL(a, b)   DOWNSCALE(DOWNSCALE_CALL(a, b))
 
#define GET_OP_CLASSES
 

Functions

template<typename PatternTy , typename... Args>
static FailureOr< LinalgOp > tryApply (Operation *operation, Args &&...args)
 Attempts to apply the pattern specified as template argument to the given operation. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
 Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. More...
 
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle)
 
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
 When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred. More...
 
template<typename Range >
static LogicalResult applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
 Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More...
 
static Operation * replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
 Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More...
 
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 Find the first "extract" user of producerOp and tile it right before its use. More...
 
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail. More...
 
static Operation * cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
 
static void printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
 
static ParseResult parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
 
template<typename RelayoutOpTy >
bool isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
 Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More...
 
static void printContinuousTileSizeTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
 
static ParseResult parseContinuousTileSizeTypes (OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
 
template<typename OpTy >
DiagnosedSilenceableFailure doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
 

Macro Definition Documentation

◆ DBGS

#define DBGS ( )    (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

Definition at line 54 of file LinalgTransformOps.cpp.

◆ DBGSNL

#define DBGSNL ( )    (llvm::dbgs() << "\n")

Definition at line 55 of file LinalgTransformOps.cpp.

◆ DEBUG_TYPE

#define DEBUG_TYPE   "linalg-transforms"

Definition at line 53 of file LinalgTransformOps.cpp.

◆ DOWNSCALE

#define DOWNSCALE (   trans)
Value:
{ \
FailureOr<LinalgOp> res = tryApply<trans>(target); \
if (succeeded(res)) { \
results.push_back(*res); \
return DiagnosedSilenceableFailure::success(); \
} \
}

◆ DOWNSCALE_CALL

#define DOWNSCALE_CALL (   a,
  b 
)    DownscaleSizeOneWindowed2DConvolution<a, b>

◆ DOWNSCALE_NORMAL

#define DOWNSCALE_NORMAL (   a,
  b 
)    DOWNSCALE(DOWNSCALE_CALL(a, b))

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 3753 of file LinalgTransformOps.cpp.

◆ LDBG

#define LDBG (   X)    LLVM_DEBUG(DBGS() << (X) << "\n")

Definition at line 56 of file LinalgTransformOps.cpp.

Function Documentation

◆ applyTilingToAll()

template<typename Range >
static LogicalResult applyTilingToAll ( RewriterBase &  rewriter,
Operation *  transformOp,
Range &&  payloadOps,
unsigned  numLoops,
transform::TransformResults &  transformResults,
function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)>  applyFn 
)
static

Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.

Definition at line 496 of file LinalgTransformOps.cpp.

References mlir::Operation::emitError(), mlir::RewriterBase::eraseOp(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceAllUsesWith(), mlir::transform::TransformResults::set(), and mlir::OpBuilder::setInsertionPoint().

◆ cloneAndFuseFirstUse()

static Operation* cloneAndFuseFirstUse ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

◆ doit()

template<typename OpTy >
DiagnosedSilenceableFailure doit ( RewriterBase &  rewriter,
OpTy  target,
transform::ApplyToEachResultList &  results,
transform::TransformState &  state 
)

◆ isValidPackingPermutation()

template<typename RelayoutOpTy >
bool isValidPackingPermutation ( RelayoutOpTy  op,
ArrayRef< int64_t >  permutation,
OuterOrInnerPerm  outerOrInnerPerm = OuterOrInnerPerm::Outer 
)

Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op.

This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` are empty to allow a simpler polymorphic implementation.

Definition at line 1590 of file LinalgTransformOps.cpp.

◆ parseContinuousTileSizeTypes()

static ParseResult parseContinuousTileSizeTypes ( OpAsmParser &  parser,
Type &  targetType,
Type &  tileSizesType,
Type &  chunkSizesType 
)
static

◆ parseMultitileSizesTypes()

static ParseResult parseMultitileSizesTypes ( OpAsmParser &  parser,
Type &  targetType,
Type &  lowSizeType,
Type &  highSizeType,
Type &  splitPointType 
)
static

◆ printContinuousTileSizeTypes()

static void printContinuousTileSizeTypes ( OpAsmPrinter &  printer,
Operation *  op,
Type  targetType,
Type  tile_sizes,
Type   
)
static

Definition at line 2793 of file LinalgTransformOps.cpp.

◆ printMultitileSizesTypes()

static void printMultitileSizesTypes ( OpAsmPrinter &  printer,
Operation *  op,
Type  targetType,
Type  lowSizeType,
Type  ,
Type   
)
static

Definition at line 1310 of file LinalgTransformOps.cpp.

◆ reifyMixedParamAndHandleResults()

static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults ( TransformState &  state,
TransformOpInterface &  transformOp,
ArrayRef< OpFoldResult >  mixedResults,
SmallVectorImpl< int64_t > &  reified 
)
static

When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred.

If a result is a Value then it must be either a ParamType or a handle to a constant-like op.

Definition at line 178 of file LinalgTransformOps.cpp.

◆ replaceForAllWithNewSignature()

static Operation* replaceForAllWithNewSignature ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp,
TilingResult &  tileAndFuseResult,
int64_t  resultNumber,
SmallVector< OpFoldResult > &  offsets,
SmallVector< OpFoldResult > &  sizes 
)
static

◆ tileAndFuseFirstExtractUse()

static std::tuple<SmallVector<Operation *>, Operation *> tileAndFuseFirstExtractUse ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

Find the first "extract" user of producerOp and tile it right before its use.

The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails. If tiled op has uses that are dominated by containingOp, return a new containingOp with results of the fused op appended to results of the containingOp or nullptr if there are no dominated uses.

Definition at line 688 of file LinalgTransformOps.cpp.

References DBGS, diag(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPoint().

◆ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument()

static SmallVector<Operation *> tileAndFuseFirstExtractUseThroughContainingOpBlockArgument ( RewriterBase &  rewriter,
Diagnostic &  diag,
Operation *  producerOp,
Operation *  containingOp 
)
static

First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp, otherwise bail.

Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the containingOp. Return this fused op on success or nullptr if anything fails.

Definition at line 770 of file LinalgTransformOps.cpp.

References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::Operation::setOperand().

◆ tryApply()

template<typename PatternTy , typename... Args>
static FailureOr<LinalgOp> tryApply ( Operation *  operation,
Args &&...  args 
)
static

Attempts to apply the pattern specified as template argument to the given operation.

The pattern is expected to have a returningMatchAndRewrite function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.

Definition at line 64 of file LinalgTransformOps.cpp.

◆ unpackSingleIndexResultPayloadOperations() [1/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState &  state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
ArrayRef< OpFoldResult >  ofrs 
)
static

Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.

Definition at line 92 of file LinalgTransformOps.cpp.

◆ unpackSingleIndexResultPayloadOperations() [2/2]

static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations ( transform::TransformState &  state,
TransformOpInterface  transformOp,
SmallVector< OpFoldResult > &  result,
Value  packedHandle 
)
static

Definition at line 144 of file LinalgTransformOps.cpp.