18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/InterleavedRange.h"
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
35 if (!isa<linalg::LinalgOp>(current)) {
36 if (getFailurePropagationMode().value_or(
37 FailurePropagationMode::Propagate) ==
38 FailurePropagationMode::Propagate) {
39 return emitSilenceableError() <<
"expected a Linalg op";
42 LLVM_DEBUG(
DBGS() <<
"optional nested matcher expected a Linalg op");
48 auto scope = state.make_region_scope(getBodyRegion());
49 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
54 for (
Operation &nested : getBody()->without_terminator()) {
56 state.applyTransform(cast<TransformOpInterface>(nested));
57 if (
diag.isDefiniteFailure())
63 assert(
diag.isSilenceableFailure());
64 if (getFailurePropagationMode().value_or(
65 FailurePropagationMode::Propagate) ==
66 FailurePropagationMode::Propagate) {
78 LLVM_DEBUG(
DBGS() <<
"optional nested matcher failed: " <<
diag.getMessage()
83 getBody()->getTerminator()->getOpOperands()) {
84 Operation *definingOp = terminatorOperand.get().getDefiningOp();
87 if (definingOp->
getBlock() != getBody())
92 undefinedOperands.push_back(&terminatorOperand);
96 auto filtered = llvm::make_filter_range(
97 getBody()->getTerminator()->getOpOperands(), [&](
OpOperand &opOperand) {
98 return !llvm::is_contained(undefinedOperands, &opOperand);
101 filtered, [](
OpOperand &opOperand) {
return opOperand.
get(); }));
103 for (
auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
116 void transform::MatchStructuredOp::getEffects(
124 if (getBody()->getNumArguments() != 1)
125 return emitOpError() <<
"expected one body argument";
126 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).
getType())) {
127 return emitOpError() <<
"expected body argument to implement "
128 "TransformHandleTypeInterface";
130 for (
Operation &nested : getBody()->without_terminator()) {
131 if (isa<MatchOpInterface>(nested))
135 <<
"expects nested operations to implement MatchOpInterface";
136 diag.attachNote(nested.getLoc()) <<
"offending operation";
148 if (!isa_and_nonnull<MatchStructuredOp>(op->
getParentOp())) {
149 return op->
emitOpError() <<
"expects parent op to be '"
150 << MatchStructuredOp::getOperationName() <<
"'";
161 <<
"expected predicate to apply to the surrounding structured op";
173 auto linalgOp = cast<linalg::LinalgOp>(current);
174 if (std::optional<uint64_t> position = getReductionPosition()) {
178 return emitSilenceableError() <<
"could not match reduction";
180 if (combinerOps.size() != 1) {
181 return emitSilenceableError() <<
"reduction combiner is not a single op";
185 if (getPassthrough()) {
188 return emitSilenceableError() <<
"not a passthrough";
192 if (getElementwise()) {
194 return emitSilenceableError() <<
"not elementwise";
197 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
200 llvm::raw_string_ostream os(message);
204 return elem->
getName().getStringRef() ==
205 cast<StringAttr>((*contractionOps)[0]).getValue() &&
207 cast<StringAttr>((*contractionOps)[1]).getValue();
212 return emitSilenceableError() <<
"contraction: " << message;
218 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219 getElementwise() + getContraction().has_value();
221 if (numOptions > 1) {
222 StringAttr attributeNames[] = {
223 getReductionPositionAttrName(), getPassthroughAttrName(),
224 getElementwiseAttrName(), getContractionAttrName()};
225 return emitOpError() <<
"only one of {" << llvm::interleaved(attributeNames)
229 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
230 if (contractionAttr->size() != 2) {
231 return emitOpError() <<
"expects " << getContractionAttrName()
232 <<
" to contain two elements";
243 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
246 FailureOr<linalg::ContractionDimensions> contractionDims =
248 if (failed(contractionDims))
249 return emitSilenceableError() <<
"could not infer contraction dimensions";
254 return llvm::to_vector(
255 llvm::map_range(values, [&](
unsigned value) ->
Attribute {
256 return builder.getI64IntegerAttr(value);
259 results.
setParams(cast<OpResult>(getBatch()),
260 makeI64Attrs(contractionDims->batch));
261 results.
setParams(cast<OpResult>(
getM()), makeI64Attrs(contractionDims->m));
262 results.
setParams(cast<OpResult>(
getN()), makeI64Attrs(contractionDims->n));
263 results.
setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
272 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
275 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
277 if (failed(convolutionDims))
278 return emitSilenceableError() <<
"could not infer convolution dimensions";
283 return llvm::to_vector(
284 llvm::map_range(values, [&](
unsigned value) ->
Attribute {
285 return builder.getI64IntegerAttr(value);
288 results.
setParams(cast<OpResult>(getBatch()),
289 makeI64Attrs(convolutionDims->batch));
290 results.
setParams(cast<OpResult>(getOutputImage()),
291 makeI64Attrs(convolutionDims->outputImage));
292 results.
setParams(cast<OpResult>(getOutputChannel()),
293 makeI64Attrs(convolutionDims->outputChannel));
294 results.
setParams(cast<OpResult>(getFilterLoop()),
295 makeI64Attrs(convolutionDims->filterLoop));
296 results.
setParams(cast<OpResult>(getInputChannel()),
297 makeI64Attrs(convolutionDims->inputChannel));
298 results.
setParams(cast<OpResult>(getDepth()),
299 makeI64Attrs(convolutionDims->depth));
302 return llvm::to_vector(
303 llvm::map_range(values, [&](int64_t value) ->
Attribute {
304 return builder.getI64IntegerAttr(value);
307 results.
setParams(cast<OpResult>(getStrides()),
308 makeI64AttrsFromI64(convolutionDims->strides));
309 results.
setParams(cast<OpResult>(getDilations()),
310 makeI64AttrsFromI64(convolutionDims->dilations));
326 const char *message) {
327 for (int64_t value : list) {
328 if (llvm::any_of(reference, [&](
unsigned ref) {
329 return static_cast<int64_t
>(ref) == value;
345 auto linalgOp = cast<linalg::LinalgOp>(current);
348 if (!
diag.succeeded())
352 if (getParallel() || getReduction()) {
355 linalgOp.getParallelDims(reference);
356 else if (getReduction())
357 linalgOp.getReductionDims(reference);
361 getParallel() ?
"expects dimension #{0} to be parallel"
362 :
"expects dimension #{0} to be reduction");
363 if (!
diag.succeeded())
374 llvm::map_range(dimensions, [&](int64_t dim) ->
Attribute {
375 return builder.getI64IntegerAttr(ranges[dim]);
377 results.
setParams(cast<OpResult>(getResult()), captured);
385 getRawDimList(), op.getNumLoops(), dims);
386 if (
diag.isSilenceableFailure()) {
387 diag.attachNote(op->getLoc())
388 <<
"while considering dimensions of this payload operation";
394 if (getParallel() && getReduction()) {
395 return emitOpError() <<
"cannot request the same dimension to be both "
396 "parallel and reduction";
399 getIsInverted(), getIsAll());
407 transform::MatchStructuredElementalBitwidthOp::matchValue(
410 auto setupResult = [&](int64_t bitwidth) {
412 results.
setParams(cast<OpResult>(getResult()), {attr});
420 if (
auto shapedType = dyn_cast<ShapedType>(type)) {
421 if (shapedType.getElementType().isIntOrFloat())
422 return setupResult(shapedType.getElementTypeBitWidth());
424 return emitSilenceableError()
425 <<
"unsupported type for bitwidth extraction: " << type;
435 auto linalgOp = cast<linalg::LinalgOp>(current);
438 if (!
diag.succeeded())
442 operandMapping.reserve(positions.size());
443 for (int64_t position : positions) {
445 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
447 return emitSilenceableError() <<
"the indexing map for input #"
448 << position <<
" is not a permutation";
451 return emitSilenceableError()
452 <<
"the indexing map for input #" << position
453 <<
" is not a projected permutation";
460 if (isa<AffineMapParamType>(getResult().
getType())) {
465 Value operand = linalgOp.getDpsInputOperand(position)->get();
466 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
467 operandMapping.emplace_back(operand);
472 if (!operandProducer) {
473 return emitSilenceableError()
474 <<
"input #" << position <<
" is not produced by an operation";
476 operandMapping.emplace_back(operandProducer);
486 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
487 op.getNumDpsInputs(), positions);
488 if (
diag.isSilenceableFailure()) {
489 diag.attachNote(op->getLoc())
490 <<
"while considering DPS inputs of this payload operation";
497 template <
typename OpTy>
499 if (op.getPermutation() && op.getProjectedPermutation()) {
500 return op.emitOpError()
501 << op.getPermutationAttrName() <<
" and "
502 << op.getProjectedPermutationAttrName() <<
" are mutually exclusive";
504 if (op.getRawPositionList().size() > 1 && op.getResult()) {
505 return op.emitOpError()
506 <<
"cannot bind multiple inputs/inits to the same value";
516 getIsInverted(), getIsAll());
526 auto linalgOp = cast<linalg::LinalgOp>(current);
529 if (!
diag.succeeded())
533 operandMapping.reserve(positions.size());
534 for (int64_t position : positions) {
536 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
538 return emitSilenceableError() <<
"the indexing map for output(init) #"
539 << position <<
" is not a permutation";
542 return emitSilenceableError() <<
"the indexing map for output(init) #"
543 << position <<
" is not a permutation";
550 if (isa<AffineMapParamType>(getResult().
getType())) {
555 Value operand = linalgOp.getDpsInitOperand(position)->get();
556 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
557 operandMapping.emplace_back(operand);
562 if (!operandProducer) {
563 return emitSilenceableError() <<
"output(init) #" << position
564 <<
" is not produced by an operation";
566 operandMapping.emplace_back(operandProducer);
576 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
577 op.getNumDpsInits(), positions);
578 if (
diag.isSilenceableFailure()) {
579 diag.attachNote(op->getLoc())
580 <<
"while considering DPS inits (outputs) of this payload operation";
589 getIsInverted(), getIsAll());
597 transform::MatchStructuredNumInputsOp::matchOperation(
600 auto linalgOp = cast<linalg::LinalgOp>(current);
603 results.
setParams(cast<OpResult>(getResult()), {attr});
612 transform::MatchStructuredNumInitsOp::matchOperation(
615 auto linalgOp = cast<linalg::LinalgOp>(current);
618 results.
setParams(cast<OpResult>(getResult()), {attr});
629 auto linalgOp = cast<linalg::LinalgOp>(current);
630 int64_t numLoops = linalgOp.getNumLoops();
632 results.
setParams(cast<OpResult>(getRank()), {attr});
643 auto linalgOp = cast<linalg::LinalgOp>(op);
646 if (!
diag.succeeded())
649 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
650 if (isa<TransformValueHandleTypeInterface>(getResult().
getType())) {
651 results.
setValues(cast<OpResult>(getResult()), {result});
656 return emitSilenceableError()
657 <<
"no users of the result #" << getPosition();
661 results.
set(cast<OpResult>(getResult()), {firstUser});
665 if (!llvm::hasSingleElement(result.
getUsers())) {
666 return emitSilenceableError()
667 <<
"more than one result user with single user requested";
669 results.
set(cast<OpResult>(getResult()), {firstUser});
677 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
679 auto rawPosition =
static_cast<int64_t
>(getPosition());
680 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
681 if (position >= op.getNumDpsInits() || position < 0) {
682 return emitSilenceableError()
683 <<
"position " << rawPosition
684 <<
" overflows the number of results(ints) of the payload operation";
690 if ((getAny() || getSingle()) ^
691 isa<TransformHandleTypeInterface>(getResult().
getType())) {
692 return emitOpError() <<
"expects either the any/single keyword or the type "
693 "value handle result type";
695 if (getAny() && getSingle()) {
696 return emitOpError() <<
"'any' and 'single' are mutually exclusive";
705 void transform::MatchStructuredYieldOp::getEffects(
711 void transform::MatchStructuredYieldOp::build(
OpBuilder &builder,
716 #define GET_OP_CLASSES
717 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI64IntegerAttr(int64_t value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.