MLIR
20.0.0git
|
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
Go to the source code of this file.
Macros | |
#define | DEBUG_TYPE "linalg-transforms" |
#define | DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
#define | DBGSNL() (llvm::dbgs() << "\n") |
#define | LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") |
#define | DOWNSCALE(trans) |
#define | DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b> |
#define | DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define | GET_OP_CLASSES |
Functions | |
template<typename PatternTy , typename... Args> | |
static FailureOr< LinalgOp > | tryApply (Operation *operation, Args &&...args) |
Attempts to apply the pattern specified as template argument to the given operation. More... | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs) |
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value. More... | |
static DiagnosedSilenceableFailure | unpackSingleIndexResultPayloadOperations (transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, Value packedHandle) |
static DiagnosedSilenceableFailure | reifyMixedParamAndHandleResults (TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified) |
When possible, converts each OpFoldResult in mixedResults to an integer if the value can be statically inferred. More... | |
template<typename Range > | |
static LogicalResult | applyTilingToAll (RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn) |
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops. More... | |
static Operation * | replaceForAllWithNewSignature (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes) |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op. More... | |
static std::tuple< SmallVector< Operation * >, Operation * > | tileAndFuseFirstExtractUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
Find the first "extract" user of producerOp and tile it right before its use. More... | |
static SmallVector< Operation * > | tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp , otherwise bail. More... | |
static Operation * | cloneAndFuseFirstUse (RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) |
static void | printMultitileSizesTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type) |
static ParseResult | parseMultitileSizesTypes (OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType) |
template<typename RelayoutOpTy > | |
bool | isValidPackingPermutation (RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer) |
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (OuterOrInnerPerm::Inner) of the tensor.pack or tensor.unpack op. More... | |
static void | printContinuousTileSizeTypes (OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type) |
static ParseResult | parseContinuousTileSizeTypes (OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType) |
template<typename OpTy > | |
DiagnosedSilenceableFailure | doit (RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state) |
#define DBGS | ( | ) | (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
Definition at line 54 of file LinalgTransformOps.cpp.
#define DBGSNL | ( | ) | (llvm::dbgs() << "\n") |
Definition at line 55 of file LinalgTransformOps.cpp.
#define DEBUG_TYPE "linalg-transforms" |
Definition at line 53 of file LinalgTransformOps.cpp.
#define DOWNSCALE | ( | trans | ) |
#define DOWNSCALE_CALL | ( | a, | |
b | |||
) | DownscaleSizeOneWindowed2DConvolution<a, b> |
#define DOWNSCALE_NORMAL | ( | a, | |
b | |||
) | DOWNSCALE(DOWNSCALE_CALL(a, b)) |
#define GET_OP_CLASSES |
Definition at line 3753 of file LinalgTransformOps.cpp.
#define LDBG | ( | X | ) | LLVM_DEBUG(DBGS() << (X) << "\n") |
Definition at line 56 of file LinalgTransformOps.cpp.
|
static |
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the created tile loops.
Definition at line 496 of file LinalgTransformOps.cpp.
References mlir::Operation::emitError(), mlir::RewriterBase::eraseOp(), mlir::Operation::getOpResult(), mlir::RewriterBase::replaceAllUsesWith(), mlir::transform::TransformResults::set(), and mlir::OpBuilder::setInsertionPoint().
|
static |
Definition at line 872 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::Operation::getOpResults(), mlir::detail::IROperandBase::getOwner(), mlir::Operation::isProperAncestor(), mlir::RewriterBase::modifyOpInPlace(), and mlir::OpBuilder::setInsertionPoint().
DiagnosedSilenceableFailure doit | ( | RewriterBase & | rewriter, |
OpTy | target, | ||
transform::ApplyToEachResultList & | results, | ||
transform::TransformState & | state | ||
) |
Definition at line 3593 of file LinalgTransformOps.cpp.
Referenced by llvm::cast_convert_val< T, ::mlir::Dialect *, ::mlir::Dialect * >::doit().
bool isValidPackingPermutation | ( | RelayoutOpTy | op, |
ArrayRef< int64_t > | permutation, | ||
OuterOrInnerPerm | outerOrInnerPerm = OuterOrInnerPerm::Outer |
||
) |
Return true if `permutation` is a valid permutation of the `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
This is the case when the `permutation` rank matches the rank expected by `op` and `permutation` is itself a permutation vector. Return true if either `op` or `permutation` is empty to allow a simpler polymorphic implementation.
Definition at line 1590 of file LinalgTransformOps.cpp.
|
static |
Definition at line 2799 of file LinalgTransformOps.cpp.
References mlir::AsmParser::emitError(), mlir::AsmParser::getCurrentLocation(), and mlir::AsmParser::parseType().
|
static |
Definition at line 1316 of file LinalgTransformOps.cpp.
References mlir::AsmParser::emitError(), mlir::AsmParser::getCurrentLocation(), and mlir::AsmParser::parseType().
|
static |
Definition at line 2793 of file LinalgTransformOps.cpp.
|
static |
Definition at line 1310 of file LinalgTransformOps.cpp.
|
static |
When possible, converts each `OpFoldResult` in `mixedResults` to an integer if the value can be statically inferred.
If a result is a `Value`, then it must be either a `ParamType` or a handle to a constant-like op.
Definition at line 178 of file LinalgTransformOps.cpp.
|
static |
Add new operands to the forall op for users of the producerOp that are dominated by the containing scf.forall op.
Definition at line 605 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::create(), mlir::DominanceInfo::dominates(), mlir::RewriterBase::eraseBlock(), mlir::Operation::getResult(), mlir::Value::getUsers(), mlir::Operation::isAncestor(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by tileAndFuseFirstExtractUse().
|
static |
Find the first "extract" user of `producerOp` and tile it right before its use.
The tiled op is fused under the `containingOp`. Return this fused op on success or nullptr if anything fails. If the tiled op has uses that are dominated by `containingOp`, return a new `containingOp` with results of the fused op appended to results of the `containingOp`, or nullptr if there are no dominated uses.
Definition at line 688 of file LinalgTransformOps.cpp.
References DBGS, diag(), mlir::Operation::getLoc(), replaceForAllWithNewSignature(), mlir::RewriterBase::replaceOp(), and mlir::OpBuilder::setInsertionPoint().
|
static |
First, find the first "scf::ForallOp" user of `producerOp` and ensure it is exactly the `containingOp`, otherwise bail.
Then, find the first "extract" user of the tied block argument and tile it right before its "extract" use. The tiled op is fused under the `containingOp`. Return this fused op on success or nullptr if anything fails.
Definition at line 770 of file LinalgTransformOps.cpp.
References mlir::OpBuilder::clone(), DBGS, diag(), mlir::RewriterBase::eraseOp(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Operation::getLoc(), mlir::OpOperand::getOperandNumber(), mlir::tensor::getOrCreateDestinations(), mlir::Value::getUsers(), mlir::IRMapping::map(), mlir::RewriterBase::modifyOpInPlace(), mlir::RewriterBase::replaceOp(), mlir::OpBuilder::setInsertionPoint(), and mlir::Operation::setOperand().
|
static |
Attempts to apply the pattern specified as template argument to the given operation.
The pattern is expected to have a returningMatchAndRewrite
function that returns the "main" result or failure. Returns failure if the pattern failed to apply. Extra arguments are forwarded to the pattern constructor.
Definition at line 64 of file LinalgTransformOps.cpp.
|
static |
Assuming that ofr
is an index attr or a param of index type or a transform dialect handle mapped to exactly one op with one index result, return that value.
Definition at line 92 of file LinalgTransformOps.cpp.
|
static |
Definition at line 144 of file LinalgTransformOps.cpp.