MLIR 23.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/Support/DebugLog.h"
20#include "llvm/Support/FormatVariadic.h"
21#include "llvm/Support/InterleavedRange.h"
22
23using namespace mlir;
24
25#define DEBUG_TYPE "linalg-transforms"
26
27//===----------------------------------------------------------------------===//
28// StructuredMatchOp
29//===----------------------------------------------------------------------===//
30
31DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32 Operation *current, transform::TransformResults &results,
34 // First, check if the payload operation is a structured Linalg operation.
35 if (!isa<linalg::LinalgOp>(current)) {
36 if (getFailurePropagationMode().value_or(
37 FailurePropagationMode::Propagate) ==
38 FailurePropagationMode::Propagate) {
39 return emitSilenceableError() << "expected a Linalg op";
40 }
41 // If errors are suppressed, succeed and set all results to empty lists.
42 LDBG() << "optional nested matcher expected a Linalg op";
43 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
45 }
46
47 // Bind `current` to the block argument.
48 auto scope = state.make_region_scope(getBodyRegion());
49 if (failed(state.mapBlockArgument(getBody()->getArgument(0),
50 MappedValue(current)))) {
52 }
53
54 for (Operation &nested : getBody()->without_terminator()) {
56 state.applyTransform(cast<TransformOpInterface>(nested));
57 if (diag.isDefiniteFailure())
58 return diag;
59 if (diag.succeeded())
60 continue;
61
62 // If propagating errors, do this immediately.
63 assert(diag.isSilenceableFailure());
64 if (getFailurePropagationMode().value_or(
65 FailurePropagationMode::Propagate) ==
66 FailurePropagationMode::Propagate) {
67 return diag;
68 }
69
70 // If suppressing errors, print the message into the debug stream before
71 // silencing it. Then set all results value that are already known.
72 // Results come from the terminator operands, which may be defined in the
73 // (single) block of this operation or above it. When they are defined
74 // above, they are known to be mapped at this point per SSA dominance.
75 // When they are defined in this block, we additionally check if we have
76 // already applied the operation that defines them. If not, the
77 // corresponding results will be set to empty lists.
78 LDBG() << "optional nested matcher failed: " << diag.getMessage();
79 (void)diag.silence();
80 SmallVector<OpOperand *> undefinedOperands;
81 for (OpOperand &terminatorOperand :
82 getBody()->getTerminator()->getOpOperands()) {
83 Operation *definingOp = terminatorOperand.get().getDefiningOp();
84 if (!definingOp)
85 continue;
86 if (definingOp->getBlock() != getBody())
87 continue;
88 if (definingOp->isBeforeInBlock(&nested))
89 continue;
90
91 undefinedOperands.push_back(&terminatorOperand);
92 }
93
95 auto filtered = llvm::make_filter_range(
96 getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
97 return !llvm::is_contained(undefinedOperands, &opOperand);
98 });
99 SmallVector<Value> definedOperands = llvm::map_to_vector(
100 filtered, [](OpOperand &opOperand) { return opOperand.get(); });
101 detail::prepareValueMappings(mappings, definedOperands, state);
102 for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
103 results.setMappedValues(getResults()[operand.getOperandNumber()],
104 mapping);
105 }
106 results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
108 }
109
110 // Set the results.
111 detail::forwardTerminatorOperands(getBody(), state, results);
113}
114
115void transform::MatchStructuredOp::getEffects(
117 onlyReadsHandle(getCurrentMutable(), effects);
118 onlyReadsPayload(effects);
119 producesHandle(getOperation()->getOpResults(), effects);
120}
121
122LogicalResult transform::MatchStructuredOp::verify() {
123 if (getBody()->getNumArguments() != 1)
124 return emitOpError() << "expected one body argument";
125 if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
126 return emitOpError() << "expected body argument to implement "
127 "TransformHandleTypeInterface";
128 }
129 for (Operation &nested : getBody()->without_terminator()) {
130 if (isa<MatchOpInterface>(nested))
131 continue;
134 << "expects nested operations to implement MatchOpInterface";
135 diag.attachNote(nested.getLoc()) << "offending operation";
136 return diag;
137 }
138 return success();
139}
140
141//===----------------------------------------------------------------------===//
142// StructuredOpPredicateOpTrait
143//===----------------------------------------------------------------------===//
144
146 Operation *op, Value structuredOpHandle) {
147 if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
148 return op->emitOpError() << "expects parent op to be '"
149 << MatchStructuredOp::getOperationName() << "'";
150 }
151
152 // Bail out here, let the verifier of the parent complain.
153 Operation *parent = op->getParentOp();
154 if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
155 parent->getRegion(0).front().getNumArguments() < 1)
156 return success();
157
158 if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
159 return op->emitOpError()
160 << "expected predicate to apply to the surrounding structured op";
161 }
162 return success();
163}
164
165//===----------------------------------------------------------------------===//
166// MatchStructuredBodyOp
167//===----------------------------------------------------------------------===//
168
169DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
170 Operation *current, transform::TransformResults &results,
172 auto linalgOp = cast<linalg::LinalgOp>(current);
173 if (std::optional<uint64_t> position = getReductionPosition()) {
174 SmallVector<Operation *> combinerOps;
175 if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
176 combinerOps)) {
177 return emitSilenceableError() << "could not match reduction";
178 }
179 if (combinerOps.size() != 1) {
180 return emitSilenceableError() << "reduction combiner is not a single op";
181 }
183 }
184 if (getPassthrough()) {
185 Block &body = linalgOp->getRegion(0).front();
186 if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
187 return emitSilenceableError() << "not a passthrough";
188 }
190 }
191 if (getElementwise()) {
192 if (!isElementwise(linalgOp))
193 return emitSilenceableError() << "not elementwise";
195 }
196 if (std::optional<ArrayAttr> contractionOps = getContraction()) {
197 Block &body = linalgOp->getRegion(0).front();
198 std::string message;
199 llvm::raw_string_ostream os(message);
201 body,
202 [&](Operation *elem, Operation *red) {
203 return elem->getName().getStringRef() ==
204 cast<StringAttr>((*contractionOps)[0]).getValue() &&
205 red->getName().getStringRef() ==
206 cast<StringAttr>((*contractionOps)[1]).getValue();
207 },
208 os);
209 if (result)
211 return emitSilenceableError() << "contraction: " << message;
212 }
213 return emitDefiniteFailure() << "unknown body condition";
214}
215
216LogicalResult transform::MatchStructuredBodyOp::verify() {
217 int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
218 getElementwise() + getContraction().has_value();
219
220 if (numOptions > 1) {
221 StringAttr attributeNames[] = {
222 getReductionPositionAttrName(), getPassthroughAttrName(),
223 getElementwiseAttrName(), getContractionAttrName()};
224 return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
225 << "} is allowed";
226 }
227
228 if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
229 if (contractionAttr->size() != 2) {
230 return emitOpError() << "expects " << getContractionAttrName()
231 << " to contain two elements";
232 }
233 }
234 return success();
235}
236
237//===----------------------------------------------------------------------===//
238// MatchStructuredClassifyContractionDimsOp
239//===----------------------------------------------------------------------===//
240
242transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
243 Operation *current, transform::TransformResults &results,
245 FailureOr<linalg::ContractionDimensions> contractionDims =
246 linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
247 if (failed(contractionDims))
248 return emitSilenceableError() << "could not infer contraction dimensions";
249
250 MLIRContext *context = current->getContext();
251 Builder builder(context);
252 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
253 return llvm::map_to_vector(values, [&](unsigned value) -> Attribute {
254 return builder.getI64IntegerAttr(value);
255 });
256 };
257 results.setParams(cast<OpResult>(getBatch()),
258 makeI64Attrs(contractionDims->batch));
259 results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
260 results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
261 results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
263}
264
265//===----------------------------------------------------------------------===//
266// MatchStructuredClassifyConvolutionDimsOp
267//===----------------------------------------------------------------------===//
268
270transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
271 Operation *current, transform::TransformResults &results,
273 FailureOr<linalg::ConvolutionDimensions> convolutionDims =
274 linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
275 if (failed(convolutionDims))
276 return emitSilenceableError() << "could not infer convolution dimensions";
277
278 MLIRContext *context = current->getContext();
279 Builder builder(context);
280 auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
281 return llvm::map_to_vector(values, [&](unsigned value) -> Attribute {
282 return builder.getI64IntegerAttr(value);
283 });
284 };
285 results.setParams(cast<OpResult>(getBatch()),
286 makeI64Attrs(convolutionDims->batch));
287 results.setParams(cast<OpResult>(getOutputImage()),
288 makeI64Attrs(convolutionDims->outputImage));
289 results.setParams(cast<OpResult>(getOutputChannel()),
290 makeI64Attrs(convolutionDims->outputChannel));
291 results.setParams(cast<OpResult>(getFilterLoop()),
292 makeI64Attrs(convolutionDims->filterLoop));
293 results.setParams(cast<OpResult>(getInputChannel()),
294 makeI64Attrs(convolutionDims->inputChannel));
295 results.setParams(cast<OpResult>(getDepth()),
296 makeI64Attrs(convolutionDims->depth));
297
298 auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
299 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
300 return builder.getI64IntegerAttr(value);
301 });
302 };
303 results.setParams(cast<OpResult>(getStrides()),
304 makeI64AttrsFromI64(convolutionDims->strides));
305 results.setParams(cast<OpResult>(getDilations()),
306 makeI64AttrsFromI64(convolutionDims->dilations));
308}
309
310//===----------------------------------------------------------------------===//
311// Utilities for structured match predicates.
312//===----------------------------------------------------------------------===//
313
314/// Checks if all values from `list` are also contained in `reference`. Returns
315/// a silenceable error with the given message at the given location when it is
316/// not the case. The error message must contain the "{0}" placeholder that
317/// will be substituted with the value from `list` that is not contained in
318/// `reference`.
321 Location loc,
322 const char *message) {
323 for (int64_t value : list) {
324 if (llvm::any_of(reference, [&](unsigned ref) {
325 return static_cast<int64_t>(ref) == value;
326 })) {
327 continue;
328 }
329 return emitSilenceableFailure(loc) << llvm::formatv(message, value);
330 }
332}
333
334//===----------------------------------------------------------------------===//
335// MatchStructuredDimOp
336//===----------------------------------------------------------------------===//
337
338DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
339 Operation *current, transform::TransformResults &results,
341 auto linalgOp = cast<linalg::LinalgOp>(current);
342 SmallVector<int64_t> dimensions;
343 DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
344 if (!diag.succeeded())
345 return diag;
346
347 // If asked to check for the kind of dimension, perform the check.
348 if (getParallel() || getReduction()) {
349 SmallVector<unsigned> reference;
350 if (getParallel())
351 linalgOp.getParallelDims(reference);
352 else if (getReduction())
353 linalgOp.getReductionDims(reference);
354
356 containsAll(reference, dimensions, getLoc(),
357 getParallel() ? "expects dimension #{0} to be parallel"
358 : "expects dimension #{0} to be reduction");
359 if (!diag.succeeded())
360 return diag;
361 }
362
363 // If not capturing, we are done here.
364 if (!getResult())
365 return diag;
366
367 SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
368 Builder builder(current);
369 SmallVector<Attribute> captured =
370 llvm::map_to_vector(dimensions, [&](int64_t dim) -> Attribute {
371 return builder.getI64IntegerAttr(ranges[dim]);
372 });
373 results.setParams(cast<OpResult>(getResult()), captured);
375}
376
377DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
378 linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
380 expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
381 getRawDimList(), op.getNumLoops(), dims);
382 if (diag.isSilenceableFailure()) {
383 diag.attachNote(op->getLoc())
384 << "while considering dimensions of this payload operation";
385 }
386 return diag;
387}
388
389LogicalResult transform::MatchStructuredDimOp::verify() {
390 if (getParallel() && getReduction()) {
391 return emitOpError() << "cannot request the same dimension to be both "
392 "parallel and reduction";
393 }
394 return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
395 getIsInverted(), getIsAll());
396}
397
398//===----------------------------------------------------------------------===//
399// MatchStructuredElementalBitwidthOp
400//===----------------------------------------------------------------------===//
401
403transform::MatchStructuredElementalBitwidthOp::matchValue(
404 Value current, transform::TransformResults &results,
406 auto setupResult = [&](int64_t bitwidth) {
407 Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
408 results.setParams(cast<OpResult>(getResult()), {attr});
410 };
411
412 Type type = current.getType();
413 if (type.isIntOrFloat())
414 return setupResult(type.getIntOrFloatBitWidth());
415
416 if (auto shapedType = dyn_cast<ShapedType>(type)) {
417 if (shapedType.getElementType().isIntOrFloat())
418 return setupResult(shapedType.getElementTypeBitWidth());
419 }
420 return emitSilenceableError()
421 << "unsupported type for bitwidth extraction: " << type;
422}
423
424//===----------------------------------------------------------------------===//
425// MatchStructuredInputOp
426//===----------------------------------------------------------------------===//
427
428DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
429 Operation *current, transform::TransformResults &results,
431 auto linalgOp = cast<linalg::LinalgOp>(current);
432 SmallVector<int64_t> positions;
433 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
434 if (!diag.succeeded())
435 return diag;
436
437 SmallVector<MappedValue> operandMapping;
438 operandMapping.reserve(positions.size());
439 for (int64_t position : positions) {
440 AffineMap indexingMap =
441 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
442 if (getPermutation() && !indexingMap.isPermutation()) {
443 return emitSilenceableError() << "the indexing map for input #"
444 << position << " is not a permutation";
445 }
446 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
447 return emitSilenceableError()
448 << "the indexing map for input #" << position
449 << " is not a projected permutation";
450 }
451
452 // If capture not requested, skip it.
453 if (!getResult())
454 continue;
455
456 if (isa<AffineMapParamType>(getResult().getType())) {
457 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
458 continue;
459 }
460
461 Value operand = linalgOp.getDpsInputOperand(position)->get();
462 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
463 operandMapping.emplace_back(operand);
464 continue;
465 }
466
467 Operation *operandProducer = operand.getDefiningOp();
468 if (!operandProducer) {
469 return emitSilenceableError()
470 << "input #" << position << " is not produced by an operation";
471 }
472 operandMapping.emplace_back(operandProducer);
473 }
474 if (getResult())
475 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
477}
478
479DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
480 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
482 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
483 op.getNumDpsInputs(), positions);
484 if (diag.isSilenceableFailure()) {
485 diag.attachNote(op->getLoc())
486 << "while considering DPS inputs of this payload operation";
487 }
488 return diag;
489}
490
491/// Verifies a matcher op for structured input or output, specifically the
492/// attributes specifying the operand positions.
493template <typename OpTy>
494LogicalResult verifyStructuredOperandOp(OpTy op) {
495 if (op.getPermutation() && op.getProjectedPermutation()) {
496 return op.emitOpError()
497 << op.getPermutationAttrName() << " and "
498 << op.getProjectedPermutationAttrName() << " are mutually exclusive";
499 }
500 if (op.getRawPositionList().size() > 1 && op.getResult()) {
501 return op.emitOpError()
502 << "cannot bind multiple inputs/inits to the same value";
503 }
504
505 return success();
506}
507
508LogicalResult transform::MatchStructuredInputOp::verify() {
510 return failure();
511 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
512 getIsInverted(), getIsAll());
513}
514
515//===----------------------------------------------------------------------===//
516// MatchStructuredInitOp
517//===----------------------------------------------------------------------===//
518
519DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
520 Operation *current, transform::TransformResults &results,
522 auto linalgOp = cast<linalg::LinalgOp>(current);
523 SmallVector<int64_t> positions;
524 DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
525 if (!diag.succeeded())
526 return diag;
527
528 SmallVector<MappedValue> operandMapping;
529 operandMapping.reserve(positions.size());
530 for (int64_t position : positions) {
531 AffineMap indexingMap =
532 linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
533 if (getPermutation() && !indexingMap.isPermutation()) {
534 return emitSilenceableError() << "the indexing map for output(init) #"
535 << position << " is not a permutation";
536 }
537 if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
538 return emitSilenceableError() << "the indexing map for output(init) #"
539 << position << " is not a permutation";
540 }
541
542 // If capture not requested, skip it.
543 if (!getResult())
544 continue;
545
546 if (isa<AffineMapParamType>(getResult().getType())) {
547 operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
548 continue;
549 }
550
551 Value operand = linalgOp.getDpsInitOperand(position)->get();
552 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
553 operandMapping.emplace_back(operand);
554 continue;
555 }
556
557 Operation *operandProducer = operand.getDefiningOp();
558 if (!operandProducer) {
559 return emitSilenceableError() << "output(init) #" << position
560 << " is not produced by an operation";
561 }
562 operandMapping.emplace_back(operandProducer);
563 }
564 if (getResult())
565 results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
567}
568
569DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
570 linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
572 getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
573 op.getNumDpsInits(), positions);
574 if (diag.isSilenceableFailure()) {
575 diag.attachNote(op->getLoc())
576 << "while considering DPS inits (outputs) of this payload operation";
577 }
578 return diag;
579}
580
581LogicalResult transform::MatchStructuredInitOp::verify() {
583 return failure();
584 return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
585 getIsInverted(), getIsAll());
586}
587
588//===----------------------------------------------------------------------===//
589// MatchStructuredNumInputsOp
590//===----------------------------------------------------------------------===//
591
593transform::MatchStructuredNumInputsOp::matchOperation(
594 Operation *current, transform::TransformResults &results,
596 auto linalgOp = cast<linalg::LinalgOp>(current);
597 Attribute attr =
598 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
599 results.setParams(cast<OpResult>(getResult()), {attr});
601}
602
603//===----------------------------------------------------------------------===//
604// MatchStructuredNumInitsOp
605//===----------------------------------------------------------------------===//
606
608transform::MatchStructuredNumInitsOp::matchOperation(
609 Operation *current, transform::TransformResults &results,
611 auto linalgOp = cast<linalg::LinalgOp>(current);
612 Attribute attr =
613 Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
614 results.setParams(cast<OpResult>(getResult()), {attr});
616}
617
618//===----------------------------------------------------------------------===//
619// MatchStructuredRankOp
620//===----------------------------------------------------------------------===//
621
622DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
623 Operation *current, transform::TransformResults &results,
625 auto linalgOp = cast<linalg::LinalgOp>(current);
626 int64_t numLoops = linalgOp.getNumLoops();
627 Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
628 results.setParams(cast<OpResult>(getRank()), {attr});
630}
631
632//===----------------------------------------------------------------------===//
633// MatchStructuredResultOp
634//===----------------------------------------------------------------------===//
635
636DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
639 auto linalgOp = cast<linalg::LinalgOp>(op);
640 int64_t position;
641 DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
642 if (!diag.succeeded())
643 return diag;
644
645 Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
646 if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
647 results.setValues(cast<OpResult>(getResult()), {result});
649 }
650
651 if (result.getUsers().empty()) {
652 return emitSilenceableError()
653 << "no users of the result #" << getPosition();
654 }
655 Operation *firstUser = *result.getUsers().begin();
656 if (getAny()) {
657 results.set(cast<OpResult>(getResult()), {firstUser});
659 }
660 if (getSingle()) {
661 if (!llvm::hasSingleElement(result.getUsers())) {
662 return emitSilenceableError()
663 << "more than one result user with single user requested";
664 }
665 results.set(cast<OpResult>(getResult()), {firstUser});
667 }
668
669 return emitDefiniteFailure() << "unknown sub-predicate";
670}
671
673transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
674 int64_t &position) {
675 auto rawPosition = static_cast<int64_t>(getPosition());
676 position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
677 if (position >= op.getNumDpsInits() || position < 0) {
678 return emitSilenceableError()
679 << "position " << rawPosition
680 << " overflows the number of results(ints) of the payload operation";
681 }
683}
684
685LogicalResult transform::MatchStructuredResultOp::verify() {
686 if ((getAny() || getSingle()) ^
687 isa<TransformHandleTypeInterface>(getResult().getType())) {
688 return emitOpError() << "expects either the any/single keyword or the type "
689 "value handle result type";
690 }
691 if (getAny() && getSingle()) {
692 return emitOpError() << "'any' and 'single' are mutually exclusive";
693 }
694 return success();
695}
696
697//===----------------------------------------------------------------------===//
698// MatchStructuredYieldOp
699//===----------------------------------------------------------------------===//
700
701void transform::MatchStructuredYieldOp::getEffects(
703 onlyReadsHandle(getHandlesMutable(), effects);
704 onlyReadsPayload(effects);
705}
706
707void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
708 OperationState &state) {
709 build(builder, state, ValueRange());
710}
711
712#define GET_OP_CLASSES
713#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable, like types, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:257
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
uint64_t getN(LevelType lt)
Definition Enums.h:442
uint64_t getM(LevelType lt)
Definition Enums.h:443
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
This represents an operation in an abstracted form, suitable for use with the builder APIs.