MLIR 22.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
18#include "llvm/Support/DebugLog.h"
19#include "llvm/Support/FormatVariadic.h"
20#include "llvm/Support/InterleavedRange.h"
21
22using namespace mlir;
23
24#define DEBUG_TYPE "linalg-transforms"
25
//===----------------------------------------------------------------------===//
// MatchStructuredOp
//===----------------------------------------------------------------------===//
29
30DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
31 Operation *current, transform::TransformResults &results,
33 // First, check if the payload operation is a structured Linalg operation.
34 if (!isa<linalg::LinalgOp>(current)) {
35 if (getFailurePropagationMode().value_or(
36 FailurePropagationMode::Propagate) ==
37 FailurePropagationMode::Propagate) {
38 return emitSilenceableError() << "expected a Linalg op";
39 }
40 // If errors are suppressed, succeed and set all results to empty lists.
41 LDBG() << "optional nested matcher expected a Linalg op";
42 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
44 }
45
46 // Bind `current` to the block argument.
47 auto scope = state.make_region_scope(getBodyRegion());
48 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
49 MappedValue(current)))) {
51 }
52
53 for (Operation &nested : getBody()->without_terminator()) {
55 state.applyTransform(cast<TransformOpInterface>(nested));
56 if (diag.isDefiniteFailure())
57 return diag;
58 if (diag.succeeded())
59 continue;
60
61 // If propagating errors, do this immediately.
62 assert(diag.isSilenceableFailure());
63 if (getFailurePropagationMode().value_or(
64 FailurePropagationMode::Propagate) ==
65 FailurePropagationMode::Propagate) {
66 return diag;
67 }
68
69 // If suppressing errors, print the message into the debug stream before
70 // silencing it. Then set all results value that are already known.
71 // Results come from the terminator operands, which may be defined in the
72 // (single) block of this operation or above it. When they are defined
73 // above, they are known to be mapped at this point per SSA dominance.
74 // When they are defined in this block, we additionally check if we have
75 // already applied the operation that defines them. If not, the
76 // corresponding results will be set to empty lists.
77 LDBG() << "optional nested matcher failed: " << diag.getMessage();
78 (void)diag.silence();
79 SmallVector<OpOperand *> undefinedOperands;
80 for (OpOperand &terminatorOperand :
81 getBody()->getTerminator()->getOpOperands()) {
82 Operation *definingOp = terminatorOperand.get().getDefiningOp();
83 if (!definingOp)
84 continue;
85 if (definingOp->getBlock() != getBody())
86 continue;
87 if (definingOp->isBeforeInBlock(&nested))
88 continue;
89
90 undefinedOperands.push_back(&terminatorOperand);
91 }
92
94 auto filtered = llvm::make_filter_range(
95 getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
96 return !llvm::is_contained(undefinedOperands, &opOperand);
97 });
98 SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
99 filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
100 detail::prepareValueMappings(mappings, definedOperands, state);
101 for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
102 results.setMappedValues(getResults()[operand.getOperandNumber()],
103 mapping);
104 }
105 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
107 }
108
109 // Set the results.
110 detail::forwardTerminatorOperands(getBody(), state, results);
112}
113
114void transform::MatchStructuredOp::getEffects(
116 onlyReadsHandle(getCurrentMutable(), effects);
117 onlyReadsPayload(effects);
118 producesHandle(getOperation()->getOpResults(), effects);
119}
120
121LogicalResult transform::MatchStructuredOp::verify() {
122 if (getBody()->getNumArguments() != 1)
123 return emitOpError() << "expected one body argument";
124 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
125 return emitOpError() << "expected body argument to implement "
126 "TransformHandleTypeInterface";
127 }
128 for (Operation &nested : getBody()->without_terminator()) {
129 if (isa<MatchOpInterface>(nested))
130 continue;
133 << "expects nested operations to implement MatchOpInterface";
134 diag.attachNote(nested.getLoc()) << "offending operation";
135 return diag;
136 }
137 return success();
138}
139
140//===----------------------------------------------------------------------===//
141// StructuredOpPredicateOpTrait
142//===----------------------------------------------------------------------===//
143
145 Operation *op, Value structuredOpHandle) {
146 if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
147 return op->emitOpError() << "expects parent op to be '"
148 << MatchStructuredOp::getOperationName() << "'";
149 }
150
151 // Bail out here, let the verifier of the parent complain.
152 Operation *parent = op->getParentOp();
153 if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
154 parent->getRegion(0).front().getNumArguments() < 1)
155 return success();
156
157 if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
158 return op->emitOpError()
159 << "expected predicate to apply to the surrounding structured op";
160 }
161 return success();
162}
163
164//===----------------------------------------------------------------------===//
165// MatchStructuredBodyOp
166//===----------------------------------------------------------------------===//
167
168DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
169 Operation *current, transform::TransformResults &results,
171 auto linalgOp = cast<linalg::LinalgOp>(current);
172 if (std::optional<uint64_t> position = getReductionPosition()) {
173 SmallVector<Operation *> combinerOps;
174 if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
175 combinerOps)) {
176 return emitSilenceableError() << "could not match reduction";
177 }
178 if (combinerOps.size() != 1) {
179 return emitSilenceableError() << "reduction combiner is not a single op";
180 }
182 }
183 if (getPassthrough()) {
184 Block &body = linalgOp->getRegion(0).front();
185 if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
186 return emitSilenceableError() << "not a passthrough";
187 }
189 }
190 if (getElementwise()) {
191 if (!isElementwise(linalgOp))
192 return emitSilenceableError() << "not elementwise";
194 }
195 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
196 Block &body = linalgOp->getRegion(0).front();
197 std::string message;
198 llvm::raw_string_ostream os(message);
200 body,
201 [&](Operation *elem, Operation *red) {
202 return elem->getName().getStringRef() ==
203 cast<StringAttr>((*contractionOps)[0]).getValue() &&
204 red->getName().getStringRef() ==
205 cast<StringAttr>((*contractionOps)[1]).getValue();
206 },
207 os);
208 if (result)
210 return emitSilenceableError() << "contraction: " << message;
211 }
212 return emitDefiniteFailure() << "unknown body condition";
213}
214
215LogicalResult transform::MatchStructuredBodyOp::verify() {
216 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
217 getElementwise() + getContraction().has_value();
218
219 if (numOptions > 1) {
220 StringAttr attributeNames[] = {
221 getReductionPositionAttrName(), getPassthroughAttrName(),
222 getElementwiseAttrName(), getContractionAttrName()};
223 return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
224 << "} is allowed";
225 }
226
227 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
228 if (contractionAttr->size() != 2) {
229 return emitOpError() << "expects " << getContractionAttrName()
230 << " to contain two elements";
231 }
232 }
233 return success();
234}
235
236//===----------------------------------------------------------------------===//
237// MatchStructuredClassifyContractionDimsOp
238//===----------------------------------------------------------------------===//
239
241transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
242 Operation *current, transform::TransformResults &results,
244 FailureOr<linalg::ContractionDimensions> contractionDims =
245 linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
246 if (failed(contractionDims))
247 return emitSilenceableError() << "could not infer contraction dimensions";
248
249 MLIRContext *context = current->getContext();
250 Builder builder(context);
251 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
252 return llvm::to_vector(
253 llvm::map_range(values, [&](unsigned value) -> Attribute {
254 return builder.getI64IntegerAttr(value);
255 }));
256 };
257 results.setParams(cast<OpResult>(getBatch()),
258 makeI64Attrs(contractionDims->batch));
259 results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
260 results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
261 results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
263}
264
265//===----------------------------------------------------------------------===//
266// MatchStructuredClassifyConvolutionDimsOp
267//===----------------------------------------------------------------------===//
268
270transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
271 Operation *current, transform::TransformResults &results,
273 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
274 linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
275 if (failed(convolutionDims))
276 return emitSilenceableError() << "could not infer convolution dimensions";
277
278 MLIRContext *context = current->getContext();
279 Builder builder(context);
280 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
281 return llvm::to_vector(
282 llvm::map_range(values, [&](unsigned value) -> Attribute {
283 return builder.getI64IntegerAttr(value);
284 }));
285 };
286 results.setParams(cast<OpResult>(getBatch()),
287 makeI64Attrs(convolutionDims->batch));
288 results.setParams(cast<OpResult>(getOutputImage()),
289 makeI64Attrs(convolutionDims->outputImage));
290 results.setParams(cast<OpResult>(getOutputChannel()),
291 makeI64Attrs(convolutionDims->outputChannel));
292 results.setParams(cast<OpResult>(getFilterLoop()),
293 makeI64Attrs(convolutionDims->filterLoop));
294 results.setParams(cast<OpResult>(getInputChannel()),
295 makeI64Attrs(convolutionDims->inputChannel));
296 results.setParams(cast<OpResult>(getDepth()),
297 makeI64Attrs(convolutionDims->depth));
298
299 auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
300 return llvm::to_vector(
301 llvm::map_range(values, [&](int64_t value) -> Attribute {
302 return builder.getI64IntegerAttr(value);
303 }));
304 };
305 results.setParams(cast<OpResult>(getStrides()),
306 makeI64AttrsFromI64(convolutionDims->strides));
307 results.setParams(cast<OpResult>(getDilations()),
308 makeI64AttrsFromI64(convolutionDims->dilations));
310}
311
312//===----------------------------------------------------------------------===//
313// Utilities for structured match predicates.
314//===----------------------------------------------------------------------===//
315
316/// Checks if all values from `list` are also contained in `reference`. Returns
317/// a silenceable error with the given message at the given location when it is
318/// not the case. The error message must contain the "{0}" placeholder that
319/// will be substituted with the value from `list` that is not contained in
320/// `reference`.
323 Location loc,
324 const char *message) {
325 for (int64_t value : list) {
326 if (llvm::any_of(reference, [&](unsigned ref) {
327 return static_cast<int64_t>(ref) == value;
328 })) {
329 continue;
330 }
331 return emitSilenceableFailure(loc) << llvm::formatv(message, value);
332 }
334}
335
336//===----------------------------------------------------------------------===//
337// MatchStructuredDimOp
338//===----------------------------------------------------------------------===//
339
340DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
341 Operation *current, transform::TransformResults &results,
343 auto linalgOp = cast<linalg::LinalgOp>(current);
344 SmallVector<int64_t> dimensions;
345 DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
346 if (!diag.succeeded())
347 return diag;
348
349 // If asked to check for the kind of dimension, perform the check.
350 if (getParallel() || getReduction()) {
351 SmallVector<unsigned> reference;
352 if (getParallel())
353 linalgOp.getParallelDims(reference);
354 else if (getReduction())
355 linalgOp.getReductionDims(reference);
356
358 containsAll(reference, dimensions, getLoc(),
359 getParallel() ? "expects dimension #{0} to be parallel"
360 : "expects dimension #{0} to be reduction");
361 if (!diag.succeeded())
362 return diag;
363 }
364
365 // If not capturing, we are done here.
366 if (!getResult())
367 return diag;
368
369 SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
370 Builder builder(current);
371 SmallVector<Attribute> captured = llvm::to_vector(
372 llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
373 return builder.getI64IntegerAttr(ranges[dim]);
374 }));
375 results.setParams(cast<OpResult>(getResult()), captured);
377}
378
379DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
380 linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
382 expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
383 getRawDimList(), op.getNumLoops(), dims);
384 if (diag.isSilenceableFailure()) {
385 diag.attachNote(op->getLoc())
386 << "while considering dimensions of this payload operation";
387 }
388 return diag;
389}
390
391LogicalResult transform::MatchStructuredDimOp::verify() {
392 if (getParallel() && getReduction()) {
393 return emitOpError() << "cannot request the same dimension to be both "
394 "parallel and reduction";
395 }
396 return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
397 getIsInverted(), getIsAll());
398}
399
400//===----------------------------------------------------------------------===//
401// MatchStructuredElementalBitwidthOp
402//===----------------------------------------------------------------------===//
403
405transform::MatchStructuredElementalBitwidthOp::matchValue(
406 Value current, transform::TransformResults &results,
408 auto setupResult = [&](int64_t bitwidth) {
409 Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
410 results.setParams(cast<OpResult>(getResult()), {attr});
412 };
413
414 Type type = current.getType();
415 if (type.isIntOrFloat())
416 return setupResult(type.getIntOrFloatBitWidth());
417
418 if (auto shapedType = dyn_cast<ShapedType>(type)) {
419 if (shapedType.getElementType().isIntOrFloat())
420 return setupResult(shapedType.getElementTypeBitWidth());
421 }
422 return emitSilenceableError()
423 << "unsupported type for bitwidth extraction: " << type;
424}
425
426//===----------------------------------------------------------------------===//
427// MatchStructuredInputOp
428//===----------------------------------------------------------------------===//
429
430DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
431 Operation *current, transform::TransformResults &results,
433 auto linalgOp = cast<linalg::LinalgOp>(current);
434 SmallVector<int64_t> positions;
435 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
436 if (!diag.succeeded())
437 return diag;
438
439 SmallVector<MappedValue> operandMapping;
440 operandMapping.reserve(positions.size());
441 for (int64_t position : positions) {
442 AffineMap indexingMap =
443 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
444 if (getPermutation() && !indexingMap.isPermutation()) {
445 return emitSilenceableError() << "the indexing map for input #"
446 << position << " is not a permutation";
447 }
448 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
449 return emitSilenceableError()
450 << "the indexing map for input #" << position
451 << " is not a projected permutation";
452 }
453
454 // If capture not requested, skip it.
455 if (!getResult())
456 continue;
457
458 if (isa<AffineMapParamType>(getResult().getType())) {
459 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
460 continue;
461 }
462
463 Value operand = linalgOp.getDpsInputOperand(position)->get();
464 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
465 operandMapping.emplace_back(operand);
466 continue;
467 }
468
469 Operation *operandProducer = operand.getDefiningOp();
470 if (!operandProducer) {
471 return emitSilenceableError()
472 << "input #" << position << " is not produced by an operation";
473 }
474 operandMapping.emplace_back(operandProducer);
475 }
476 if (getResult())
477 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
479}
480
481DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
482 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
484 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
485 op.getNumDpsInputs(), positions);
486 if (diag.isSilenceableFailure()) {
487 diag.attachNote(op->getLoc())
488 << "while considering DPS inputs of this payload operation";
489 }
490 return diag;
491}
492
493/// Verifies a matcher op for structured input or output, specifically the
494/// attributes specifying the operand positions.
495template <typename OpTy>
496LogicalResult verifyStructuredOperandOp(OpTy op) {
497 if (op.getPermutation() && op.getProjectedPermutation()) {
498 return op.emitOpError()
499 << op.getPermutationAttrName() << " and "
500 << op.getProjectedPermutationAttrName() << " are mutually exclusive";
501 }
502 if (op.getRawPositionList().size() > 1 && op.getResult()) {
503 return op.emitOpError()
504 << "cannot bind multiple inputs/inits to the same value";
505 }
506
507 return success();
508}
509
510LogicalResult transform::MatchStructuredInputOp::verify() {
512 return failure();
513 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
514 getIsInverted(), getIsAll());
515}
516
517//===----------------------------------------------------------------------===//
518// MatchStructuredInitOp
519//===----------------------------------------------------------------------===//
520
521DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
522 Operation *current, transform::TransformResults &results,
524 auto linalgOp = cast<linalg::LinalgOp>(current);
525 SmallVector<int64_t> positions;
526 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
527 if (!diag.succeeded())
528 return diag;
529
530 SmallVector<MappedValue> operandMapping;
531 operandMapping.reserve(positions.size());
532 for (int64_t position : positions) {
533 AffineMap indexingMap =
534 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
535 if (getPermutation() && !indexingMap.isPermutation()) {
536 return emitSilenceableError() << "the indexing map for output(init) #"
537 << position << " is not a permutation";
538 }
539 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
540 return emitSilenceableError() << "the indexing map for output(init) #"
541 << position << " is not a permutation";
542 }
543
544 // If capture not requested, skip it.
545 if (!getResult())
546 continue;
547
548 if (isa<AffineMapParamType>(getResult().getType())) {
549 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
550 continue;
551 }
552
553 Value operand = linalgOp.getDpsInitOperand(position)->get();
554 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
555 operandMapping.emplace_back(operand);
556 continue;
557 }
558
559 Operation *operandProducer = operand.getDefiningOp();
560 if (!operandProducer) {
561 return emitSilenceableError() << "output(init) #" << position
562 << " is not produced by an operation";
563 }
564 operandMapping.emplace_back(operandProducer);
565 }
566 if (getResult())
567 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
569}
570
571DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
572 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
574 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
575 op.getNumDpsInits(), positions);
576 if (diag.isSilenceableFailure()) {
577 diag.attachNote(op->getLoc())
578 << "while considering DPS inits (outputs) of this payload operation";
579 }
580 return diag;
581}
582
583LogicalResult transform::MatchStructuredInitOp::verify() {
585 return failure();
586 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
587 getIsInverted(), getIsAll());
588}
589
590//===----------------------------------------------------------------------===//
591// MatchStructuredNumInputsOp
592//===----------------------------------------------------------------------===//
593
595transform::MatchStructuredNumInputsOp::matchOperation(
596 Operation *current, transform::TransformResults &results,
598 auto linalgOp = cast<linalg::LinalgOp>(current);
599 Attribute attr =
600 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
601 results.setParams(cast<OpResult>(getResult()), {attr});
603}
604
605//===----------------------------------------------------------------------===//
606// MatchStructuredNumInitsOp
607//===----------------------------------------------------------------------===//
608
610transform::MatchStructuredNumInitsOp::matchOperation(
611 Operation *current, transform::TransformResults &results,
613 auto linalgOp = cast<linalg::LinalgOp>(current);
614 Attribute attr =
615 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
616 results.setParams(cast<OpResult>(getResult()), {attr});
618}
619
620//===----------------------------------------------------------------------===//
621// MatchStructuredRankOp
622//===----------------------------------------------------------------------===//
623
624DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
625 Operation *current, transform::TransformResults &results,
627 auto linalgOp = cast<linalg::LinalgOp>(current);
628 int64_t numLoops = linalgOp.getNumLoops();
629 Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
630 results.setParams(cast<OpResult>(getRank()), {attr});
632}
633
634//===----------------------------------------------------------------------===//
635// MatchStructuredResultOp
636//===----------------------------------------------------------------------===//
637
638DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
641 auto linalgOp = cast<linalg::LinalgOp>(op);
642 int64_t position;
643 DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
644 if (!diag.succeeded())
645 return diag;
646
647 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
648 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
649 results.setValues(cast<OpResult>(getResult()), {result});
651 }
652
653 if (result.getUsers().empty()) {
654 return emitSilenceableError()
655 << "no users of the result #" << getPosition();
656 }
657 Operation *firstUser = *result.getUsers().begin();
658 if (getAny()) {
659 results.set(cast<OpResult>(getResult()), {firstUser});
661 }
662 if (getSingle()) {
663 if (!llvm::hasSingleElement(result.getUsers())) {
664 return emitSilenceableError()
665 << "more than one result user with single user requested";
666 }
667 results.set(cast<OpResult>(getResult()), {firstUser});
669 }
670
671 return emitDefiniteFailure() << "unknown sub-predicate";
672}
673
675transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
676 int64_t &position) {
677 auto rawPosition = static_cast<int64_t>(getPosition());
678 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
679 if (position >= op.getNumDpsInits() || position < 0) {
680 return emitSilenceableError()
681 << "position " << rawPosition
682 << " overflows the number of results(ints) of the payload operation";
683 }
685}
686
687LogicalResult transform::MatchStructuredResultOp::verify() {
688 if ((getAny() || getSingle()) ^
689 isa<TransformHandleTypeInterface>(getResult().getType())) {
690 return emitOpError() << "expects either the any/single keyword or the type "
691 "value handle result type";
692 }
693 if (getAny() && getSingle()) {
694 return emitOpError() << "'any' and 'single' are mutually exclusive";
695 }
696 return success();
697}
698
699//===----------------------------------------------------------------------===//
700// MatchStructuredYieldOp
701//===----------------------------------------------------------------------===//
702
703void transform::MatchStructuredYieldOp::getEffects(
705 onlyReadsHandle(getHandlesMutable(), effects);
706 onlyReadsPayload(effects);
707}
708
709void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
710 OperationState &state) {
711 build(builder, state, ValueRange());
712}
713
714#define GET_OP_CLASSES
715#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:215
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
This represents an operation in an abstracted form, suitable for use with the builder APIs.