22 using namespace tensor;
// External model that attaches transform::FindPayloadReplacementOpInterface
// to tensor::ExtractSliceOp. The visible body casts `op` to
// tensor::ExtractSliceOp and returns a braced list whose only element is the
// slice's source value.
// NOTE(review): this listing skips the lines between the cast and the return
// (and the method signature), so any guard that runs before the return is not
// visible here -- confirm against the full file.
29 struct ExtractSliceOpReplacementInterface
30 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
36 return {extractSliceOp.getSource()};
// External model that attaches transform::FindPayloadReplacementOpInterface
// to tensor::InsertSliceOp. The visible body casts `op` to
// tensor::InsertSliceOp and returns a braced list whose only element is the
// inserted (source) value.
// NOTE(review): this listing skips the lines between the cast and the return
// (and the method signature), so any guard that runs before the return is not
// visible here -- confirm against the full file.
40 struct InsertSliceOpReplacementInterface
41 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
47 return {insertSliceOp.getSource()};
// External model that attaches transform::FindPayloadReplacementOpInterface
// to tensor::ReshapeOp. The visible body casts `op` to tensor::ReshapeOp and
// immediately returns a braced list whose only element is the reshape's
// source value.
// NOTE(review): the method signature and closing braces are not part of this
// listing.
51 struct ReshapeOpReplacementInterface
52 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
53 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
55 auto reshapeOp = cast<tensor::ReshapeOp>(op);
56 return {reshapeOp.getSource()};
// External model of transform::FindPayloadReplacementOpInterface that is
// parameterized over the op type `ConcreteOp`, so one definition can be
// attached to several ops. The visible body casts `op` to `ConcreteOp` and
// returns a braced list whose only element is `getSrc()` of that op;
// `ConcreteOp` therefore has to provide a `getSrc()` accessor.
// NOTE(review): the method signature and closing braces are not part of this
// listing.
60 template <
typename ConcreteOp>
61 struct ReassociativeReshapeOpReplacementInterface
62 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
65 auto reshapeOp = cast<ConcreteOp>(op);
66 return {reshapeOp.getSrc()};
// Attaches the replacement-interface external models to five tensor ops via
// `*ctx`: CollapseShapeOp and ExpandShapeOp share the templated
// ReassociativeReshapeOpReplacementInterface, while ExtractSliceOp,
// InsertSliceOp and ReshapeOp each get their dedicated model.
// NOTE(review): the enclosing function header (and where `ctx` comes from) is
// not part of this listing.
74 CollapseShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76 ExpandShapeOp::attachInterface<
77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
// Opening lines of the member definitions for the transform::Apply*PatternsOp
// classes. Only the first line of each definition is present in this listing;
// parameter lists and bodies are not shown. Heads that end in `::` are cut
// before the member name (presumably populatePatterns as well -- confirm).
88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
103 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
108 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
113 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
118 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
123 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
// Predicate fragment: fetches the operation defining the value held by
// `fusedOperand` and yields true only when such an operation exists (the
// pointer is non-null) and that operation has exactly one use.
// NOTE(review): presumably the body of the control function handed to the
// rewrite-as-constant patterns -- the enclosing lambda/function is not shown.
131 Operation *producer = fusedOperand->get().getDefiningOp();
132 return producer && producer->
hasOneUse();
// Fragment of a transform::TypeConversionCastShapeDynamicDimsOp member (the
// member name is cut off after `::`). It reads the op's ignore-dynamic-info
// flag once and then shows two callback bodies with the same shape:
//   1. require exactly one input value,
//   2. require tensor::CastOp::areCastCompatible(input type, resultType),
//   3. build a tensor::CastOp from `input` to `resultType` at `loc` and
//      return its result.
// The first body has an additional condition that only applies when
// `ignoreDynamicInfo` is false; the rest of that condition is not shown.
// NOTE(review): what is returned when a check fails, and which registration
// call each body belongs to (presumably source vs. target materialization),
// is not visible in this listing -- confirm against the full file.
150 void transform::TypeConversionCastShapeDynamicDimsOp::
152 bool ignoreDynamicInfo = getIgnoreDynamicInfo();
// First callback body.
157 if (inputs.size() != 1) {
160 Value input = inputs[0];
161 if (!ignoreDynamicInfo &&
165 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
168 return builder.
create<tensor::CastOp>(loc, resultType, input).
getResult();
// Second callback body: same checks, without the ignoreDynamicInfo condition.
173 if (inputs.size() != 1) {
176 Value input = inputs[0];
177 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
180 return builder.
create<tensor::CastOp>(loc, resultType, input).
getResult();
// Fragment of a transform-op application routine that tries to rewrite
// `target` so that it no longer depends on its enclosing loops.
//
// Visible steps:
//   * iterate getNumLoops() times, pushing the induction variable of each
//     `nextOp` (cast to scf::ForOp) into `ivs`; a diagnostic
//     "could not find <i>-th enclosing loop" with a "target op" note exists
//     for the failure case (its triggering condition is not shown);
//   * dispatch on the type of `target`: tensor::PadOp, tensor::EmptyOp, or
//     otherwise an "unsupported target op" diagnostic;
//   * when `replacement` is a failure, emit the silenceable error
//     "could not make target op loop-independent";
//   * otherwise replace `target` with the replacement value and append the
//     replacement's defining op to `results`.
// NOTE(review): the function header, how `nextOp` is obtained, the bodies of
// the two dispatch branches, and the return statements are not part of this
// listing.
195 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
199 <<
"could not find " << i
200 <<
"-th enclosing loop";
201 diag.attachNote(target->
getLoc()) <<
"target op";
204 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
// Starts out as failure; only the (hidden) branch bodies below can set it.
208 FailureOr<Value> replacement = failure();
209 if (
auto padOp = dyn_cast<tensor::PadOp>(target)) {
211 }
else if (
auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
215 <<
"unsupported target op";
216 diag.attachNote(target->
getLoc()) <<
"target op";
219 if (failed(replacement)) {
221 emitSilenceableError() <<
"could not make target op loop-independent";
222 diag.attachNote(target->
getLoc()) <<
"target op";
// Success path: swap in the new value and report the op that produces it.
225 rewriter.
replaceOp(target, *replacement);
226 results.
push_back(replacement->getDefiningOp());
// Fragment of the TensorTransformDialectExtension class and the generated-code
// include that follows it. The visible lines call declareGeneratedDialect for
// affine::AffineDialect and tensor::TensorDialect, and call
// registerTransformOps with a template argument list taken from the generated
// TensorTransformOps.cpp.inc file.
// NOTE(review): the base-class name, the enclosing member function, and the
// macro defined before the first #include are not part of this listing.
235 class TensorTransformDialectExtension
237 TensorTransformDialectExtension> {
244 declareGeneratedDialect<affine::AffineDialect>();
245 declareGeneratedDialect<tensor::TensorDialect>();
247 registerTransformOps<
249 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
// Second inclusion of the same generated file, this time with GET_OP_CLASSES
// defined to select a different section of it.
255 #define GET_OP_CLASSES
256 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
std::function< bool(OpOperand *)> ControlFoldFn
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns