19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
35 if (!isa<linalg::LinalgOp>(current)) {
36 if (getFailurePropagationMode().value_or(
37 FailurePropagationMode::Propagate) ==
38 FailurePropagationMode::Propagate) {
39 return emitSilenceableError() <<
"expected a Linalg op";
42 LLVM_DEBUG(
DBGS() <<
"optional nested matcher expected a Linalg op");
48 auto scope = state.make_region_scope(getBodyRegion());
49 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
54 for (
Operation &nested : getBody()->without_terminator()) {
56 state.applyTransform(cast<TransformOpInterface>(nested));
57 if (
diag.isDefiniteFailure())
63 assert(
diag.isSilenceableFailure());
64 if (getFailurePropagationMode().value_or(
65 FailurePropagationMode::Propagate) ==
66 FailurePropagationMode::Propagate) {
78 LLVM_DEBUG(
DBGS() <<
"optional nested matcher failed: " <<
diag.getMessage()
83 getBody()->getTerminator()->getOpOperands()) {
84 Operation *definingOp = terminatorOperand.get().getDefiningOp();
87 if (definingOp->
getBlock() != getBody())
92 undefinedOperands.push_back(&terminatorOperand);
96 auto filtered = llvm::make_filter_range(
97 getBody()->getTerminator()->getOpOperands(), [&](
OpOperand &opOperand) {
98 return !llvm::is_contained(undefinedOperands, &opOperand);
101 filtered, [](
OpOperand &opOperand) {
return opOperand.
get(); }));
103 for (
auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
116 void transform::MatchStructuredOp::getEffects(
124 if (getBody()->getNumArguments() != 1)
125 return emitOpError() <<
"expected one body argument";
126 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).
getType())) {
127 return emitOpError() <<
"expected body argument to implement "
128 "TransformHandleTypeInterface";
130 for (
Operation &nested : getBody()->without_terminator()) {
131 if (isa<MatchOpInterface>(nested))
135 <<
"expects nested operations to implement MatchOpInterface";
136 diag.attachNote(nested.getLoc()) <<
"offending operation";
148 if (!isa_and_nonnull<MatchStructuredOp>(op->
getParentOp())) {
149 return op->
emitOpError() <<
"expects parent op to be '"
150 << MatchStructuredOp::getOperationName() <<
"'";
161 <<
"expected predicate to apply to the surrounding structured op";
173 auto linalgOp = cast<linalg::LinalgOp>(current);
174 if (std::optional<uint64_t> position = getReductionPosition()) {
178 return emitSilenceableError() <<
"could not match reduction";
180 if (combinerOps.size() != 1) {
181 return emitSilenceableError() <<
"reduction combiner is not a single op";
185 if (getPassthrough()) {
188 return emitSilenceableError() <<
"not a passthrough";
192 if (getElementwise()) {
194 return emitSilenceableError() <<
"not elementwise";
197 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
200 llvm::raw_string_ostream os(message);
204 return elem->
getName().getStringRef() ==
205 cast<StringAttr>((*contractionOps)[0]).getValue() &&
207 cast<StringAttr>((*contractionOps)[1]).getValue();
212 return emitSilenceableError() <<
"contraction: " << message;
218 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219 getElementwise() + getContraction().has_value();
221 if (numOptions > 1) {
222 std::string attributeNames;
223 llvm::raw_string_ostream os(attributeNames);
225 getPassthroughAttrName(),
226 getElementwiseAttrName(),
227 getContractionAttrName()},
229 return emitOpError() <<
"only one of {" << attributeNames <<
"} is allowed";
232 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
233 if (contractionAttr->size() != 2) {
234 return emitOpError() <<
"expects " << getContractionAttrName()
235 <<
" to contain two elements";
246 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
249 FailureOr<linalg::ContractionDimensions> contractionDims =
251 if (failed(contractionDims))
252 return emitSilenceableError() <<
"could not infer contraction dimensions";
257 return llvm::to_vector(
258 llvm::map_range(values, [&](
unsigned value) ->
Attribute {
259 return builder.getI64IntegerAttr(value);
262 results.
setParams(cast<OpResult>(getBatch()),
263 makeI64Attrs(contractionDims->batch));
264 results.
setParams(cast<OpResult>(
getM()), makeI64Attrs(contractionDims->m));
265 results.
setParams(cast<OpResult>(
getN()), makeI64Attrs(contractionDims->n));
266 results.
setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
275 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
278 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
280 if (failed(convolutionDims))
281 return emitSilenceableError() <<
"could not infer convolution dimensions";
286 return llvm::to_vector(
287 llvm::map_range(values, [&](
unsigned value) ->
Attribute {
288 return builder.getI64IntegerAttr(value);
291 results.
setParams(cast<OpResult>(getBatch()),
292 makeI64Attrs(convolutionDims->batch));
293 results.
setParams(cast<OpResult>(getOutputImage()),
294 makeI64Attrs(convolutionDims->outputImage));
295 results.
setParams(cast<OpResult>(getOutputChannel()),
296 makeI64Attrs(convolutionDims->outputChannel));
297 results.
setParams(cast<OpResult>(getFilterLoop()),
298 makeI64Attrs(convolutionDims->filterLoop));
299 results.
setParams(cast<OpResult>(getInputChannel()),
300 makeI64Attrs(convolutionDims->inputChannel));
301 results.
setParams(cast<OpResult>(getDepth()),
302 makeI64Attrs(convolutionDims->depth));
305 return llvm::to_vector(
306 llvm::map_range(values, [&](int64_t value) ->
Attribute {
307 return builder.getI64IntegerAttr(value);
310 results.
setParams(cast<OpResult>(getStrides()),
311 makeI64AttrsFromI64(convolutionDims->strides));
312 results.
setParams(cast<OpResult>(getDilations()),
313 makeI64AttrsFromI64(convolutionDims->dilations));
329 const char *message) {
330 for (int64_t value : list) {
331 if (llvm::any_of(reference, [&](
unsigned ref) {
332 return static_cast<int64_t
>(ref) == value;
348 auto linalgOp = cast<linalg::LinalgOp>(current);
351 if (!
diag.succeeded())
355 if (getParallel() || getReduction()) {
358 linalgOp.getParallelDims(reference);
359 else if (getReduction())
360 linalgOp.getReductionDims(reference);
364 getParallel() ?
"expects dimension #{0} to be parallel"
365 :
"expects dimension #{0} to be reduction");
366 if (!
diag.succeeded())
377 llvm::map_range(dimensions, [&](int64_t dim) ->
Attribute {
378 return builder.getI64IntegerAttr(ranges[dim]);
380 results.
setParams(cast<OpResult>(getResult()), captured);
388 getRawDimList(), op.getNumLoops(), dims);
389 if (
diag.isSilenceableFailure()) {
390 diag.attachNote(op->getLoc())
391 <<
"while considering dimensions of this payload operation";
397 if (getParallel() && getReduction()) {
398 return emitOpError() <<
"cannot request the same dimension to be both "
399 "parallel and reduction";
402 getIsInverted(), getIsAll());
410 transform::MatchStructuredElementalBitwidthOp::matchValue(
413 auto setupResult = [&](int64_t bitwidth) {
415 results.
setParams(cast<OpResult>(getResult()), {attr});
423 if (
auto shapedType = dyn_cast<ShapedType>(type)) {
424 if (shapedType.getElementType().isIntOrFloat())
425 return setupResult(shapedType.getElementTypeBitWidth());
427 return emitSilenceableError()
428 <<
"unsupported type for bitwidth extraction: " << type;
438 auto linalgOp = cast<linalg::LinalgOp>(current);
441 if (!
diag.succeeded())
445 operandMapping.reserve(positions.size());
446 for (int64_t position : positions) {
448 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
450 return emitSilenceableError() <<
"the indexing map for input #"
451 << position <<
" is not a permutation";
454 return emitSilenceableError()
455 <<
"the indexing map for input #" << position
456 <<
" is not a projected permutation";
463 if (isa<AffineMapParamType>(getResult().
getType())) {
468 Value operand = linalgOp.getDpsInputOperand(position)->get();
469 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
470 operandMapping.emplace_back(operand);
475 if (!operandProducer) {
476 return emitSilenceableError()
477 <<
"input #" << position <<
" is not produced by an operation";
479 operandMapping.emplace_back(operandProducer);
489 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
490 op.getNumDpsInputs(), positions);
491 if (
diag.isSilenceableFailure()) {
492 diag.attachNote(op->getLoc())
493 <<
"while considering DPS inputs of this payload operation";
500 template <
typename OpTy>
502 if (op.getPermutation() && op.getProjectedPermutation()) {
503 return op.emitOpError()
504 << op.getPermutationAttrName() <<
" and "
505 << op.getProjectedPermutationAttrName() <<
" are mutually exclusive";
507 if (op.getRawPositionList().size() > 1 && op.getResult()) {
508 return op.emitOpError()
509 <<
"cannot bind multiple inputs/inits to the same value";
519 getIsInverted(), getIsAll());
529 auto linalgOp = cast<linalg::LinalgOp>(current);
532 if (!
diag.succeeded())
536 operandMapping.reserve(positions.size());
537 for (int64_t position : positions) {
539 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
541 return emitSilenceableError() <<
"the indexing map for output(init) #"
542 << position <<
" is not a permutation";
545 return emitSilenceableError() <<
"the indexing map for output(init) #"
546 << position <<
" is not a permutation";
553 if (isa<AffineMapParamType>(getResult().
getType())) {
558 Value operand = linalgOp.getDpsInitOperand(position)->get();
559 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
560 operandMapping.emplace_back(operand);
565 if (!operandProducer) {
566 return emitSilenceableError() <<
"output(init) #" << position
567 <<
" is not produced by an operation";
569 operandMapping.emplace_back(operandProducer);
579 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
580 op.getNumDpsInits(), positions);
581 if (
diag.isSilenceableFailure()) {
582 diag.attachNote(op->getLoc())
583 <<
"while considering DPS inits (outputs) of this payload operation";
592 getIsInverted(), getIsAll());
600 transform::MatchStructuredNumInputsOp::matchOperation(
603 auto linalgOp = cast<linalg::LinalgOp>(current);
606 results.
setParams(cast<OpResult>(getResult()), {attr});
615 transform::MatchStructuredNumInitsOp::matchOperation(
618 auto linalgOp = cast<linalg::LinalgOp>(current);
621 results.
setParams(cast<OpResult>(getResult()), {attr});
632 auto linalgOp = cast<linalg::LinalgOp>(current);
633 int64_t numLoops = linalgOp.getNumLoops();
635 results.
setParams(cast<OpResult>(getRank()), {attr});
646 auto linalgOp = cast<linalg::LinalgOp>(op);
649 if (!
diag.succeeded())
652 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
653 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
654 results.
setValues(cast<OpResult>(getResult()), {result});
659 return emitSilenceableError()
660 <<
"no users of the result #" << getPosition();
664 results.
set(cast<OpResult>(getResult()), {firstUser});
668 if (!llvm::hasSingleElement(result.
getUsers())) {
669 return emitSilenceableError()
670 <<
"more than one result user with single user requested";
672 results.
set(cast<OpResult>(getResult()), {firstUser});
680 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
682 auto rawPosition =
static_cast<int64_t
>(getPosition());
683 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
684 if (position >= op.getNumDpsInits() || position < 0) {
685 return emitSilenceableError()
686 <<
"position " << rawPosition
687 <<
" overflows the number of results(ints) of the payload operation";
693 if ((getAny() || getSingle()) ^
694 isa<TransformHandleTypeInterface>(getResult().
getType())) {
695 return emitOpError() <<
"expects either the any/single keyword or the type "
696 "value handle result type";
698 if (getAny() && getSingle()) {
699 return emitOpError() <<
"'any' and 'single' are mutually exclusive";
708 void transform::MatchStructuredYieldOp::getEffects(
714 void transform::MatchStructuredYieldOp::build(
OpBuilder &builder,
719 #define GET_OP_CLASSES
720 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.