MLIR 23.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
39
40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/SmallVectorExtras.h"
45#include "llvm/ADT/StringSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/FormatVariadic.h"
48#include "llvm/Support/InterleavedRange.h"
49#include "llvm/Support/LogicalResult.h"
50#include "llvm/Support/MathExtras.h"
51#include "llvm/Support/raw_ostream.h"
52#include <cassert>
53#include <optional>
54
55using namespace mlir;
56using namespace mlir::linalg;
57
58/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
// NOTE(review): the signature line was lost in extraction; presumably
// `static OpFoldResult <name>(OpBuilder &builder, Location loc, Value v,
// int64_t dim)` -- confirm against the original file.
 60 int64_t dim) {
 61 auto type = cast<ShapedType>(v.getType());
// Static dimension: fold directly to a constant index attribute.
 62 if (!type.isDynamicDim(dim))
 63 return builder.getIndexAttr(type.getDimSize(dim));
 64
// Dynamic dimension: materialize the matching dim op for the shaped type.
// NOTE(review): a line (presumably `TypeSwitch<Type, Value>(...)`) is missing
// here in the extracted text -- confirm against the original file.
 65 return getAsOpFoldResult(
 67 .Case([&](RankedTensorType t) -> Value {
 68 return tensor::DimOp::create(builder, loc, v, dim);
 69 })
 70 .Case([&](MemRefType t) -> Value {
 71 return memref::DimOp::create(builder, loc, v, dim);
 72 }));
 73}
74
 75/// Returns a memref.subview or a tensor.extract_slice based on the type of the
 76/// `source`.
// NOTE(review): the signature lines were lost in extraction; presumably a
// static helper taking (OpBuilder &b, Location loc, Value source,
// ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, ...) --
// confirm against the original file.
 80 ArrayRef<OpFoldResult> strides) {
// NOTE(review): a `TypeSwitch<Type, Operation *>(...)` line appears to be
// missing here in the extracted text.
 82 .Case([&](RankedTensorType t) -> Operation * {
 83 return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
 84 strides);
 85 })
 86 .Case([&](MemRefType type) -> Operation * {
 87 return memref::SubViewOp::create(b, loc, source, offsets, sizes,
 88 strides);
 89 })
// Unsupported source types yield no op; callers must handle nullptr.
 90 .Default([&](Type t) -> Operation * { return nullptr; });
 91}
92
93//===----------------------------------------------------------------------===//
94// Helper functions
95//===----------------------------------------------------------------------===//
96
// Create (or fold) a dim op for `source` at `dim`, dispatching on whether the
// source is a memref or a tensor.
// NOTE(review): the signature line was lost in extraction; presumably
// `Value mlir::linalg::createOrFoldDimOp(OpBuilder &b, Location loc,
// Value source, int64_t dim)` -- confirm against the original file.
 98 int64_t dim) {
 99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
 100 return b.createOrFold<memref::DimOp>(loc, source, dim);
 101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
 102 return b.createOrFold<tensor::DimOp>(loc, source, dim);
// Only memref-like and tensor-like sources are supported.
 103 llvm_unreachable("Expected MemRefType or TensorType");
 104}
105
// Return the size of `source` at `dim` as an OpFoldResult: a constant index
// attribute for static dims, otherwise a materialized dim op.
// NOTE(review): the signature line was lost in extraction -- confirm the
// exact name/parameters against the original file.
 107 int64_t dim) {
 108 auto shapedType = llvm::cast<ShapedType>(source.getType());
// Unranked or dynamic dims need a runtime dim op.
 109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
 110 return createOrFoldDimOp(b, loc, source, dim);
 111 return b.getIndexAttr(shapedType.getDimSize(dim));
 112}
113
114//===----------------------------------------------------------------------===//
115// Support for named Linalg ops defined in ods-gen.
116//===----------------------------------------------------------------------===//
117
121
122/// Fills the region of a structured operation using the provided
123/// `regionBuilder`. The method is used by both named structured ops created by
124/// ods-gen and by manually defined C++ ops. It is called by both builders and
125/// parsers and creates a block with arguments corresponding to the elemental
126/// types of `inputTypes` and `outputTypes`.
127static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
128 TypeRange inputTypes, TypeRange outputTypes,
// NOTE(review): two parameter lines (presumably `ArrayRef<NamedAttribute>
// attrs` and an `emitError` callback -- both used below) were lost in
// extraction.
131 RegionBuilderFn regionBuilder) {
132 SmallVector<Type, 8> argTypes;
// NOTE(review): the declaration of `argLocs` (used below) was lost in
// extraction; presumably `SmallVector<Location, 8> argLocs;`.
// Shaped operands contribute their elemental type as the block argument
// type; everything else is passed through unchanged.
134 for (auto containers : {inputTypes, outputTypes}) {
135 for (auto t : containers) {
136 argTypes.push_back(
137 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
138
139 // TODO: Pass in a proper location here.
140 argLocs.push_back(opBuilder.getUnknownLoc());
141 }
142 }
143
144 // RAII.
145 OpBuilder::InsertionGuard guard(opBuilder);
146 Block *body =
147 opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
148
// Delegate the actual payload construction to the per-op region builder.
149 opBuilder.setInsertionPointToStart(body);
150 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
151 regionBuilder(b, *body, attrs, emitError);
152
153 // indexing_maps is an auto-generated method.
154
155 // iterator_types is an auto-generated method.
156}
157
158/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
159/// The result types are derived automatically if `resultTensorTypes` is none.
160/// The body of the operation is filled using `regionBuilder`. All ods-gen
161/// created structured operations use the method to implement their builders.
// NOTE(review): the signature line was lost in extraction; presumably
// `static void buildStructuredOp(OpBuilder &b, OperationState &state,` --
// confirm against the original file.
163 std::optional<TypeRange> resultTensorTypes,
164 ValueRange inputs, ValueRange outputs,
165 ArrayRef<NamedAttribute> attributes,
166 RegionBuilderFn regionBuilder) {
167 // Derive the result types if needed: on tensors, each ranked-tensor output
167 // becomes a result; on buffers there are no results.
168 SmallVector<Type> derivedResultTypes =
169 resultTensorTypes.value_or(TypeRange());
170 if (!resultTensorTypes)
171 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
173
174 state.addOperands(inputs);
175 state.addOperands(outputs);
176 state.addTypes(derivedResultTypes);
177
178 state.addAttributes(attributes);
// Record the ins/outs split for the variadic operand segments.
179 state.addAttribute(
180 "operandSegmentSizes",
181 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
182 static_cast<int32_t>(outputs.size())}));
183
184 // Create and fill the region of the structured operation.
185 Region &region = *state.addRegion();
186 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
187 state.attributes.getAttrs(), /*emitError=*/{},
188 regionBuilder);
189}
190
// Builder overload that falls back to `defaultIndexingMaps` when the caller
// did not supply an "indexing_maps" attribute, then delegates to
// buildStructuredOp.
// NOTE(review): the signature line was lost in extraction; presumably
// `static void build<Op>Op(OpBuilder &b, OperationState &state,` -- confirm
// the exact name against the original file.
 192 std::optional<TypeRange> resultTensorTypes,
 193 ValueRange inputs, ValueRange outputs,
 194 ArrayRef<NamedAttribute> attributes,
 195 RegionBuilderFn regionBuilder,
 196 ArrayRef<AffineMap> defaultIndexingMaps) {
 197 // If indexing maps are not provided, apply the default ones.
 198 if (none_of(attributes, [](NamedAttribute attr) {
 199 return attr.getName() == "indexing_maps";
 200 })) {
 201 SmallVector<Attribute, 3> indexingMapsAttrVal;
 202 indexingMapsAttrVal = llvm::map_to_vector(
 203 defaultIndexingMaps,
 204 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 205 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
 206 }
 207 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
 208 attributes, regionBuilder);
 209}
210
// Builder overload that falls back to `defaultIndexingMaps` when the caller
// did not supply an "indexing_maps" attribute, then delegates to
// buildStructuredOp. Structurally identical to the previous overload (only
// the SmallVector inline capacity differs).
// NOTE(review): the signature line was lost in extraction -- confirm the
// exact name against the original file.
 212 std::optional<TypeRange> resultTensorTypes,
 213 ValueRange inputs, ValueRange outputs,
 214 ArrayRef<NamedAttribute> attributes,
 215 RegionBuilderFn regionBuilder,
 216 ArrayRef<AffineMap> defaultIndexingMaps) {
 217 // If indexing maps are not provided, apply the default ones.
 218 if (none_of(attributes, [](NamedAttribute attr) {
 219 return attr.getName() == "indexing_maps";
 220 })) {
 221 SmallVector<Attribute, 4> indexingMapsAttrVal;
 222 indexingMapsAttrVal = llvm::map_to_vector(
 223 defaultIndexingMaps,
 224 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 225 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
 226 }
 227 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
 228 attributes, regionBuilder);
 229}
230
// Builder overload that unconditionally installs the caller-provided
// `indexingMaps` as the "indexing_maps" attribute (used for
// BatchReduceMatmulOp per the comment below), then delegates to
// buildStructuredOp.
// NOTE(review): the signature line was lost in extraction -- confirm the
// exact name against the original file.
 232 std::optional<TypeRange> resultTensorTypes,
 233 ValueRange inputs, ValueRange outputs,
 234 ArrayRef<NamedAttribute> attributes,
 235 RegionBuilderFn regionBuilder,
 236 ArrayRef<AffineMap> indexingMaps) {
 237 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
 238 SmallVector<Attribute, 4> indexingMapsAttrVal;
 239 indexingMapsAttrVal =
 240 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
 241 return AffineMapAttr::get(map);
 242 });
 243 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
 244 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
 245 attributes, regionBuilder);
 246}
247
248/// Common parsing used for both named structured ops created by ods-gen and by
249/// manually defined C++ ops. Does not handle regions.
250static ParseResult
// NOTE(review): the line carrying the function name and first parameters
// (presumably `parseCommonStructuredOpParts(OpAsmParser &parser,
// OperationState &result,`) was lost in extraction.
 252 SmallVectorImpl<Type> &inputTypes,
 253 SmallVectorImpl<Type> &outputTypes,
 254 bool addOperandSegmentSizes = true) {
 255 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
// NOTE(review): the declaration line for `inputsOperands`/`outputsOperands`
// (presumably SmallVectors of OpAsmParser::UnresolvedOperand) was lost in
// extraction.
 257 outputsOperands;
 258
// Optional `<...>` properties dictionary preceding the attribute dict.
 259 if (succeeded(parser.parseOptionalLess())) {
 260 if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
 261 return failure();
 262 }
 263 attrsLoc = parser.getCurrentLocation();
 264 if (parser.parseOptionalAttrDict(result.attributes))
 265 return failure();
 266
// Optional `ins(%a, %b : type, type)` clause.
 267 if (succeeded(parser.parseOptionalKeyword("ins"))) {
 268 if (parser.parseLParen())
 269 return failure();
 270
 271 inputsOperandsLoc = parser.getCurrentLocation();
 272 if (parser.parseOperandList(inputsOperands) ||
 273 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
 274 return failure();
 275 }
 276
// Optional `outs(%c : type)` clause.
 277 if (succeeded(parser.parseOptionalKeyword("outs"))) {
 278 outputsOperandsLoc = parser.getCurrentLocation();
 279 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
 280 parser.parseColonTypeList(outputTypes) || parser.parseRParen())
 281 return failure();
 282 }
 283
 284 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
 285 result.operands) ||
 286 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
 287 result.operands))
 288 return failure();
 289
 290 if (addOperandSegmentSizes) {
 291 // This is a bit complex because we're trying to be backward compatible with
 292 // operation syntax that mix the inherent attributes and the discardable
 293 // ones in the same dictionary. If the properties are used, we append the
 294 // operandSegmentSizes there directly. Otherwise we append it to the
 295 // discardable attributes dictionary where it is handled by the generic
 296 // Operation::create(...) method.
 297 if (result.propertiesAttr) {
 298 NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
// NOTE(review): the attribute-construction line (presumably
// `parser.getBuilder().getDenseI32ArrayAttr(`) was lost in extraction here
// and in the else-branch below.
 299 attrs.append("operandSegmentSizes",
 301 {static_cast<int32_t>(inputsOperands.size()),
 302 static_cast<int32_t>(outputsOperands.size())}));
 303 result.propertiesAttr = attrs.getDictionary(parser.getContext());
 304 } else {
 305 result.addAttribute("operandSegmentSizes",
 307 {static_cast<int32_t>(inputsOperands.size()),
 308 static_cast<int32_t>(outputsOperands.size())}));
 309 }
 310 }
// Without properties, inherent attributes live in the discardable dict and
// must be verified explicitly.
 311 if (!result.propertiesAttr) {
 312 std::optional<RegisteredOperationName> info =
 313 result.name.getRegisteredInfo();
 314 if (info) {
 315 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
 316 return parser.emitError(attrsLoc)
 317 << "'" << result.name.getStringRef() << "' op ";
 318 })))
 319 return failure();
 320 }
 321 }
 322 return success();
 323}
324
// Print the `ins(...)`/`outs(...)` clauses; empty ranges are elided.
// NOTE(review): the signature line (presumably
// `static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange
// inputs,`) was lost in extraction.
 326 ValueRange outputs) {
 327 if (!inputs.empty())
 328 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
 329 if (!outputs.empty())
 330 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 331}
332
333//===----------------------------------------------------------------------===//
334// Specific parsing and printing for named structured ops created by ods-gen.
335//===----------------------------------------------------------------------===//
336
// Parse-time region construction for named structured ops: validates the
// expected region arity, then delegates to the region filler.
// NOTE(review): the signature line (presumably
// `static ParseResult parseNamedStructuredOpRegion(`) was lost in extraction.
 338 OpAsmParser &parser, Region &region, unsigned numRegionArgs,
 339 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
 340 RegionBuilderFn regionBuilder, SMLoc loc) {
 341 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
 342 return parser.emitError(
 343 parser.getCurrentLocation(),
 344 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
 345 "region expects {0} args, got {1}",
 346 numRegionArgs, inputTypes.size() + outputTypes.size()));
 347 }
 348
 349 OpBuilder opBuilder(parser.getContext());
// The error callback flips `result` so a diagnostic raised inside the region
// builder is reported as a parse failure.
// NOTE(review): the call line (presumably `fillStructuredOpRegion(`) was lost
// in extraction.
 350 ParseResult result = success();
 352 opBuilder, region, inputTypes, outputTypes, attrs,
 353 [&]() {
 354 result = failure();
 355 return parser.emitError(loc);
 356 },
 357 regionBuilder);
 358 return result;
 359}
360
// Parse the optional `-> type, ...` result list of a named structured op.
361static ParseResult
// NOTE(review): the line carrying the function name and first parameter
// (presumably `parseNamedStructuredOpResults(OpAsmParser &parser,`) was lost
// in extraction.
 363 SmallVectorImpl<Type> &resultTypes) {
 364 if (parser.parseOptionalArrowTypeList(resultTypes))
 365 return failure();
 366 return success();
 367}
368
// Top-level parser for ods-gen named structured ops: common ins/outs parts,
// optional attribute dict, result types, then the (rebuilt) region.
369static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
// NOTE(review): a parameter line (presumably `OperationState &result,`) was
// lost in extraction.
 371 unsigned numRegionArgs,
 372 RegionBuilderFn regionBuilder) {
 373 // TODO: Enable when ods-gen supports captures.
 374 SmallVector<Type, 1> inputTypes, outputTypes;
 375 SMLoc loc = parser.getCurrentLocation();
 376 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
 377 return failure();
 378
 379 // Parse optional attributes.
 380 if (parser.parseOptionalAttrDict(result.attributes))
 381 return failure();
 382
 383 // TODO: consider merging results parsing into region parsing.
 384 // Need to wait for declarative assembly resolution to decide.
 385 SmallVector<Type, 1> outputTensorsTypes;
 386 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
 387 return failure();
 388 result.addTypes(outputTensorsTypes);
 389
// The region is not parsed from the input; it is regenerated from the
// op's region builder so the printed form can elide it.
 390 std::unique_ptr<Region> region = std::make_unique<Region>();
 391 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
 392 outputTypes, result.attributes.getAttrs(),
 393 regionBuilder, loc))
 394 return failure();
 395 result.addRegion(std::move(region));
 396
 397 return success();
 398}
399
// Print the `-> type, ...` result list; elided when there are no results.
// NOTE(review): the signature line (presumably
// `static void printNamedStructuredOpResults(OpAsmPrinter &p,`) was lost in
// extraction.
 401 TypeRange resultTypes) {
 402 if (resultTypes.empty())
 403 return;
 404 p.printOptionalArrowTypeList(resultTypes);
 405}
406
// Printer counterpart of parseNamedStructuredOp: attribute dict, common
// ins/outs parts, results; the region is elided (it is regenerated at parse
// time).
// NOTE(review): the signature line (presumably
// `static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,`) was
// lost in extraction.
 408 ValueRange inputs, ValueRange outputs,
 409 ArrayRef<StringRef> elidedAttrs = {}) {
 410 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 411
 412 // Printing is shared with generic ops, except for the region and
 413 // attributes.
 414 printCommonStructuredOpParts(p, inputs, outputs);
 415
 416 // Results printing.
// NOTE(review): the call line (presumably
// `printNamedStructuredOpResults(p, op->getResultTypes());`) was lost in
// extraction.
 418
 419 // Region is elided.
 420}
421
422//===----------------------------------------------------------------------===//
423// Region builder helper.
424// TODO: Move this to a utility library.
425// The public methods on this class are referenced directly from generated code.
426// Helper build the unary, binary, and type conversion functions defined by the
427// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
428// class.
429//
430// Implementations of the math functions must be polymorphic over numeric types,
431// internally performing necessary casts. If the function application makes no
432// sense, then the only recourse is to assert and return nullptr. This can be
433// extended later if it becomes possible to fail construction of the region. The
434// invariant should be enforced at a higher level.
435//
436// TODO: These helpers are currently type polymorphic over the class of integer
437// and floating point types, but they will not internally cast within bit
438// widths of a class (mixed precision such as i8->i32) or across classes
439// (i.e. mixed float and integer). Many such combinations are ambiguous or need
440// to be handled with care and work is being considered to extend the op
441// language to make such cases explicit. In the mean-time, violating this will
442// fail verification, which is deemed acceptable.
443//===----------------------------------------------------------------------===//
444
445namespace {
446
447class RegionBuilderHelper {
448public:
449 RegionBuilderHelper(OpBuilder &builder, Block &block)
450 : builder(builder), block(block) {}
451
452 // Build the unary functions defined by OpDSL.
453 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
454 function_ref<InFlightDiagnostic()> emitError = {}) {
455 if (!isFloatingPoint(arg)) {
456 if (emitError) {
457 emitError() << "unsupported non numeric type";
458 return nullptr;
459 }
460 llvm_unreachable("unsupported non numeric type");
461 }
462 OpBuilder::InsertionGuard g(builder);
463 builder.setInsertionPointToEnd(&block);
464 switch (unaryFn) {
465 case UnaryFn::exp:
466 return math::ExpOp::create(builder, arg.getLoc(), arg);
467 case UnaryFn::log:
468 return math::LogOp::create(builder, arg.getLoc(), arg);
469 case UnaryFn::abs:
470 return math::AbsFOp::create(builder, arg.getLoc(), arg);
471 case UnaryFn::ceil:
472 return math::CeilOp::create(builder, arg.getLoc(), arg);
473 case UnaryFn::floor:
474 return math::FloorOp::create(builder, arg.getLoc(), arg);
475 case UnaryFn::negf:
476 return arith::NegFOp::create(builder, arg.getLoc(), arg);
477 case UnaryFn::reciprocal: {
478 Attribute oneAttr = builder.getOneAttr(arg.getType());
479 auto one = arith::ConstantOp::create(builder, arg.getLoc(),
480 ::cast<TypedAttr>(oneAttr));
481 return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
482 }
483 case UnaryFn::round:
484 return math::RoundOp::create(builder, arg.getLoc(), arg);
485 case UnaryFn::sqrt:
486 return math::SqrtOp::create(builder, arg.getLoc(), arg);
487 case UnaryFn::rsqrt:
488 return math::RsqrtOp::create(builder, arg.getLoc(), arg);
489 case UnaryFn::square:
490 return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
491 case UnaryFn::tanh:
492 return math::TanhOp::create(builder, arg.getLoc(), arg);
493 case UnaryFn::erf:
494 return math::ErfOp::create(builder, arg.getLoc(), arg);
495 }
496 if (emitError) {
497 emitError() << "unsupported unary function";
498 return nullptr;
499 }
500 llvm_unreachable("unsupported unary function");
501 }
502
503 // Build the binary functions defined by OpDSL.
504 // If emitError is provided, an error will be emitted if the operation is not
505 // supported and a nullptr will be returned, otherwise an assertion will be
506 // raised.
507 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
508 function_ref<InFlightDiagnostic()> emitError = {}) {
509 bool allComplex = isComplex(arg0) && isComplex(arg1);
510 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
511 bool allInteger = isInteger(arg0) && isInteger(arg1);
512 bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
513 arg1.getType().getIntOrFloatBitWidth() == 1;
514 if (!allComplex && !allFloatingPoint && !allInteger) {
515 if (emitError) {
516 emitError()
517 << "Cannot build binary Linalg operation: expects allComplex, "
518 "allFloatingPoint, or allInteger, got "
519 << arg0.getType() << " and " << arg1.getType();
520 return nullptr;
521 }
522 llvm_unreachable("unsupported non numeric type");
523 }
524 OpBuilder::InsertionGuard g(builder);
525 builder.setInsertionPointToEnd(&block);
526 switch (binaryFn) {
527 case BinaryFn::add:
528 if (allComplex)
529 return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
530 if (allFloatingPoint)
531 return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
532 if (allBool)
533 return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
534 return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
535 case BinaryFn::sub:
536 if (allComplex)
537 return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
538 if (allFloatingPoint)
539 return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
540 if (allBool) {
541 if (emitError) {
542 emitError() << "unsupported operation: sub with bools";
543 return nullptr;
544 }
545 llvm_unreachable("unsupported operation: sub with bools");
546 }
547 return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
548 case BinaryFn::mul:
549 if (allComplex)
550 return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
551 if (allFloatingPoint)
552 return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
553 if (allBool)
554 return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
555 return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
556 case BinaryFn::div:
557 if (allComplex)
558 return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
559 if (allFloatingPoint)
560 return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
561 if (allBool) {
562 if (emitError) {
563 emitError() << "unsupported operation: div with bools";
564 return nullptr;
565 }
566 llvm_unreachable("unsupported operation: div with bools");
567 }
568 return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
569 case BinaryFn::div_unsigned:
570 if (!allInteger || allBool) {
571 if (emitError) {
572 emitError() << "unsupported operation: unsigned div not on uint";
573 return nullptr;
574 }
575 llvm_unreachable("unsupported operation: unsigned div not on uint");
576 }
577 return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
578 case BinaryFn::max_signed:
579 assert(!allComplex);
580 if (allFloatingPoint)
581 return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
582 return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
583 case BinaryFn::min_signed:
584 assert(!allComplex);
585 if (allFloatingPoint)
586 return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
587 return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
588 case BinaryFn::max_unsigned:
589 assert(!allComplex);
590 if (!allInteger || allBool) {
591 if (emitError) {
592 emitError() << "unsupported operation: unsigned max not on uint";
593 return nullptr;
594 }
595 llvm_unreachable("unsupported operation: unsigned max not on uint");
596 }
597 return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
598 case BinaryFn::min_unsigned:
599 assert(!allComplex);
600 if (!allInteger || allBool) {
601 if (emitError) {
602 emitError() << "unsupported operation: unsigned min not on uint";
603 return nullptr;
604 }
605 llvm_unreachable("unsupported operation: unsigned min not on uint");
606 }
607 return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
608 case BinaryFn::powf:
609 assert(allFloatingPoint);
610 return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
611 }
612 if (emitError) {
613 emitError() << "unsupported binary function";
614 return nullptr;
615 }
616 llvm_unreachable("unsupported binary function");
617 }
618
619 // Build the ternary functions defined by OpDSL.
620 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
621 function_ref<InFlightDiagnostic()> emitError = {}) {
622 OpBuilder::InsertionGuard g(builder);
623 builder.setInsertionPointToEnd(&block);
624 switch (ternaryFn) {
625 case TernaryFn::select:
626 return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
627 }
628 if (emitError) {
629 emitError() << "unsupported ternary function";
630 return nullptr;
631 }
632 llvm_unreachable("unsupported ternary function");
633 }
634
635 // Build the type functions defined by OpDSL.
636 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
637 function_ref<InFlightDiagnostic()> emitError = {}) {
638 switch (typeFn) {
639 case TypeFn::cast_signed:
640 return cast(toType, operand, false);
641 case TypeFn::cast_unsigned:
642 return cast(toType, operand, true);
643 }
644 if (emitError) {
645 emitError() << "unsupported type conversion function";
646 return nullptr;
647 }
648 llvm_unreachable("unsupported type conversion function");
649 }
650
651 void yieldOutputs(ValueRange values) {
652 OpBuilder::InsertionGuard g(builder);
653 builder.setInsertionPointToEnd(&block);
654 Location loc = builder.getUnknownLoc();
655 YieldOp::create(builder, loc, values);
656 }
657
658 Value constant(const std::string &value) {
659 OpBuilder::InsertionGuard g(builder);
660 builder.setInsertionPointToEnd(&block);
661 Location loc = builder.getUnknownLoc();
662 Attribute valueAttr = parseAttribute(value, builder.getContext());
663 return arith::ConstantOp::create(builder, loc,
664 ::cast<TypedAttr>(valueAttr));
665 }
666
667 Value index(int64_t dim) {
668 OpBuilder::InsertionGuard g(builder);
669 builder.setInsertionPointToEnd(&block);
670 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
671 }
672
673 Type getIntegerType(unsigned width) {
674 return IntegerType::get(builder.getContext(), width);
675 }
676
677 Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
678 Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
679
680private:
681 // Generates operations to cast the given operand to a specified type.
682 // If the cast cannot be performed, a warning will be issued and the
683 // operand returned as-is (which will presumably yield a verification
684 // issue downstream).
685 Value cast(Type toType, Value operand, bool isUnsignedCast) {
686 OpBuilder::InsertionGuard g(builder);
687 builder.setInsertionPointToEnd(&block);
688 auto loc = operand.getLoc();
689 if (isa<UnknownLoc>(loc)) {
690 if (operand.getDefiningOp())
691 loc = operand.getDefiningOp()->getLoc();
692 else if (operand.getParentBlock() &&
693 operand.getParentBlock()->getParentOp())
694 loc = operand.getParentBlock()->getParentOp()->getLoc();
695 }
696 return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
697 }
698
699 bool isComplex(Value value) {
700 return llvm::isa<ComplexType>(value.getType());
701 }
702 bool isFloatingPoint(Value value) {
703 return llvm::isa<FloatType>(value.getType());
704 }
705 bool isInteger(Value value) {
706 return llvm::isa<IntegerType>(value.getType());
707 }
708
709 OpBuilder &builder;
710 Block &block;
711};
712
713} // namespace
714
715//===----------------------------------------------------------------------===//
716// CopyOp
717//===----------------------------------------------------------------------===//
718
719namespace {
720
721struct EraseSelfCopy : OpRewritePattern<CopyOp> {
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter) const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
726 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
727 if (copyOp.hasPureBufferSemantics())
728 rewriter.eraseOp(copyOp);
729 else
730 rewriter.replaceOp(copyOp, copyOp.getInputs());
731
732 return success();
733 }
734};
735
736} // namespace
737
// Register CopyOp canonicalizations: currently only the removal of self
// copies (EraseSelfCopy above).
738void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.add<EraseSelfCopy>(context);
741}
742
743//===----------------------------------------------------------------------===//
744// FillOp
745//===----------------------------------------------------------------------===//
746
747namespace {
748
749/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
750///
751/// For such op chains, we can create new linalg.fill ops with the result
752/// type of the tensor.expand/collapse_shape op.
753template <typename TensorReshapeOp>
754struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter) const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
759 if (!oldFill)
760 return failure();
761
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
765
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
770 } else {
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
774 }
775 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
776 ValueRange{newInit});
777 return success();
778 }
779};
780
781/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
782/// filling value are the same.
783struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
// NOTE(review): the `using OpRewritePattern<tensor::PadOp>::OpRewritePattern;`
// line appears to have been lost in extraction here.
 785
 786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
 787 PatternRewriter &rewriter) const override {
 788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
 789 if (!fillOp)
 790 return failure();
 791
 792 // We can only fold if the padding value is the same as the original
 793 // filling value.
 794 Value padValue = padOp.getConstantPaddingValue();
 795 if (!padValue || fillOp.value() != padValue)
 796 return failure();
 797
// The pad's result shape must be computable to size the replacement tensor.
 798 ReifiedRankedShapedTypeDims reifiedShape;
 799 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
 800 return rewriter.notifyMatchFailure(
 801 padOp, "failed to reify tensor.pad op result shape");
 802
// Build an empty tensor of the padded shape and fill it directly with the
// shared pad/fill value.
 803 auto emptyTensor =
 804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
 805 padOp.getResultType().getElementType());
 806 Value replacement =
 807 FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
 808 ValueRange{emptyTensor})
 809 .getResult(0);
// Reconcile static/dynamic type differences with an explicit cast.
 810 if (replacement.getType() != padOp.getResultType()) {
 811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
 812 padOp.getResultType(), replacement);
 813 }
 814 rewriter.replaceOp(padOp, replacement);
 815 return success();
 816 }
 817};
818
819/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
820/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
821/// filling value are the same.
822struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
// NOTE(review): the `using OpRewritePattern<tensor::InsertSliceOp>::
// OpRewritePattern;` line appears to have been lost in extraction here.
 824
 825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
 826 PatternRewriter &rewriter) const override {
 827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
 828 if (!srcPadOp)
 829 return failure();
 830
// Rank-changing insert_slices are not handled.
 831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
 832 return failure();
 833
 834 // Walk back the tensor.insert_slice chain and find the first destination
 835 // value at the start of the chain.
 836 Value firstDest = insertOp.getDest();
 837 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
 838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
 839 return failure();
 840
 841 // Make sure the range of values accessed are disjoint. Without this, we
 842 // cannot fold tensor.pad away.
 843 bool disjoint = false;
 844 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
 845 // If the dimension has dynamic offset/size, we cannot guarantee
 846 // disjoint. So just skip it.
 847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
 848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
 849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
 850 continue;
 851
 852 // Get the range start and end, inclusively for both.
 853 int64_t prevStart = prevOp.getStaticOffset(i);
 854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
 855 prevOp.getStaticStride(i);
 856 int64_t nextStart = insertOp.getStaticOffset(i);
 857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
 858 insertOp.getStaticStride(i);
// Disjoint in any one dimension means the accessed regions cannot overlap.
 859 if (prevEnd < nextStart || nextEnd < prevStart) {
 860 disjoint = true;
 861 break;
 862 }
 863 }
 864
 865 if (!disjoint)
 866 break;
 867 firstDest = prevOp.getDest();
 868 }
 869
 870 // Check whether the first destination is a fill op. For overlapped cases,
 871 // this also cannot be true.
 872 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
 873 if (!dstFillOp)
 874 return failure();
 875
 876 // We can only fold if the padding value is the same as the original
 877 // filling value.
 878 Value padValue = srcPadOp.getConstantPaddingValue();
 879 if (!padValue || dstFillOp.value() != padValue)
 880 return failure();
 881
 882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
 883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
 884
 885 Location loc = insertOp.getLoc();
 886 MLIRContext *context = getContext();
 887
// Affine map computing `offset + lowPad` for each dimension.
 888 AffineExpr sym0, sym1;
 889 bindSymbols(context, sym0, sym1);
 890 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
 891
 892 // Calculate the new offsets for the insert. It should be the old offsets
 893 // plus low padding sizes.
 894 SmallVector<OpFoldResult, 4> newOffsets;
 895 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
 896 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
 897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
 898 }
 899
// New sizes come from the pad's *source*: static dims become index
// attributes, dynamic dims become tensor.dim ops.
 900 RankedTensorType srcPadType = srcPadOp.getSourceType();
 901 SmallVector<OpFoldResult, 4> newSizes;
 902 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
 903 if (srcPadType.isDynamicDim(i)) {
 904 newSizes.push_back(
 905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
 906 .getResult());
 907 } else {
 908 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
 909 }
 910 }
 911
// Insert the unpadded source directly; the fill supplies the pad area.
 912 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
 913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
 914 newSizes, insertOp.getMixedStrides());
 915 return success();
 916 }
 917};
918
919/// Fold tensor.extract(linalg.fill(<input>)) into <input>
920struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
921public:
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
923
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter) const override {
926 // See if tensor input of tensor.extract op is the result of a linalg.fill
927 // op.
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
929 if (!fillOp)
930 return failure();
931
932 // Get scalar input operand of linalg.fill op.
933 Value extractedScalar = fillOp.getInputs()[0];
934
935 // Replace tensor.extract op with scalar value used to fill the tensor.
936 rewriter.replaceOp(extractOp, extractedScalar);
937 return success();
938 }
939};
940
941/// Folds pack(fill) into a single fill op if
942/// 1. The pack op does not have padding value, or
943/// 2. The filled value and padding value are the same.
944static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
947 if (!fillOp)
948 return failure();
949
950 if (auto paddingValue = packOp.getPaddingValue())
951 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
952 return failure();
953
954 Value packOpDest = packOp.getDest();
955 if (!packOpDest.hasOneUse())
956 return failure();
957
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
959 packOp.getDest());
960}
961
962/// Wrapper pattern that applies foldFillPackIntoFillOp method.
963struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
964public:
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
967
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter) const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
971 if (failed(fillOp))
972 return failure();
973 rewriter.replaceOp(packOp, fillOp.value().result());
974 return success();
975 }
976};
977
978/// Fold fill with copy.
979struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
981
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter) const override {
984 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
985 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
986 fillOp.getInputs(),
987 copyOp.getOutputs());
988 return success();
989 }
990 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
991 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
992 fillOp.getOutputs());
993 return success();
994 }
995 return failure();
996 }
997};
998
999/// Fold fill with transpose.
1000struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1002
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter) const override {
1005 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1006 rewriter.replaceOpWithNewOp<FillOp>(
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1009 return success();
1010 }
1011 return failure();
1012 }
1013};
1014
1015/// Fold a concat with all elements being fills of the same value
1016/// into a fill of the concat result shape.
1017struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
1019
1020 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1021 PatternRewriter &rewriter) const override {
1022 auto concatOperands = concatOp.getInputs();
1023 if (concatOperands.empty()) {
1024 return failure();
1025 }
1026
1027 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1028 if (!firstFillOp) {
1029 return failure();
1030 }
1031 // Prefetch the fill value.
1032 OpFoldResult firstFillVal =
1033 getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
1034 // Collect all the outs values for the fill operations.
1035 SmallVector<Value> allOuts;
1036 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1037
1038 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
1039 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1040 if (!fillOp) {
1041 return false;
1042 }
1043
1044 OpFoldResult fillVal =
1045 getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
1046 if (fillVal != firstFillVal)
1047 return false;
1048
1049 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1050 return true;
1051 };
1052 if (!llvm::all_of(concatOperands.drop_front(),
1053 isDefinedByCompatibleFillOp)) {
1054 return rewriter.notifyMatchFailure(
1055 concatOp, "not all operands are defined by a compatible fill op");
1056 }
1057
1058 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1059 concatOp.getDim(), allOuts);
1060 rewriter.replaceOpWithNewOp<linalg::FillOp>(
1061 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
1062 return success();
1063 }
1064};
1065
1066} // namespace
1067
/// Register fill-related canonicalizations: folding fills into consuming
/// copy/extract/pack/pad/reshape/insert_slice/transpose ops and merging
/// concatenations of identical fills.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
1076
1077//===----------------------------------------------------------------------===//
1078// GenericOp
1079//===----------------------------------------------------------------------===//
1080
1082 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1083 ValueRange outputs,
1084 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1085 SmallVector<Type, 4> blockArgTypes;
1086 SmallVector<Location, 4> blockArgLocs;
1087 for (ValueRange container : {inputs, outputs}) {
1088 for (Value v : container) {
1089 Type t = v.getType();
1090 blockArgTypes.push_back(
1091 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
1092 blockArgLocs.push_back(v.getLoc());
1093 }
1094 }
1095
1096 OpBuilder::InsertionGuard guard(builder);
1097 Block *bodyBlock =
1098 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1099 bodyBuild(builder, loc, bodyBlock->getArguments());
1100}
1101
1102void GenericOp::getAsmBlockArgumentNames(Region &region,
1103 OpAsmSetValueNameFn setNameFn) {
1104 for (Value v : getRegionInputArgs())
1105 setNameFn(v, "in");
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v, "out");
1108}
1109
1110void GenericOp::build(
1111 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1112 ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1113 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1114 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1115 ArrayRef<NamedAttribute> attributes) {
1116 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1117 iteratorTypes, doc, libraryCall);
1118 result.addAttributes(attributes);
1119 if (bodyBuild)
1120 buildGenericRegion(builder, result.location, *result.regions.front(),
1121 inputs, outputs, bodyBuild);
1122}
1123
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1126 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1129 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder, result, resultTensorTypes, inputs, outputs,
1132 builder.getAffineMapArrayAttr(indexingMaps),
1133 builder.getArrayAttr(llvm::map_to_vector(
1134 iteratorTypes,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1137 })),
1138 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1139 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1141}
1142
/// Build a GenericOp with no result tensor types (e.g. for memref operands).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  // Delegate with an empty result-type list.
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
1153
/// Build a GenericOp with no result tensor types and empty doc/library_call.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1164
/// Build a GenericOp with result tensor types but empty doc/library_call.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1176
/// Print linalg.generic: leading trait dictionary, ins/outs, optional
/// `attrs = {...}` for non-trait attributes, region, and result types.
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  // Collect only the core linalg trait attributes into a leading dictionary.
  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
          iteratorTypes, [&](utils::IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          });

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // `operandSegmentSizes` is implied by the ins/outs syntax, so elide it from
  // the trailing attribute dictionary as well.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Any attribute outside the trait set is printed under `attrs = {...}`.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1238
/// Parse linalg.generic: trait dictionary, ins/outs, optional `attrs = {...}`,
/// region, and optional tensor result types. Inverse of GenericOp::print.
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  // Replace the placeholder "_" entry with the dictionary's actual contents.
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1304
1307 &effects,
1308 LinalgOp linalgOp) {
1309 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1310 if (!llvm::isa<MemRefType>(operand.getType()))
1311 continue;
1312 effects.emplace_back(
1313 MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1314 /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1315 }
1316
1317 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1318 if (!llvm::isa<MemRefType>(operand.get().getType()))
1319 continue;
1320 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1321 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1322 /*effectOnFullRegion=*/true,
1324 }
1325 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1326 /*effectOnFullRegion=*/true,
1328 }
1329}
1330
/// Memory effects are fully determined by the shared structured-op helper.
void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1336
1339 // Operands with value semantics are speculatable, while operands with memory
1340 // semantics are not.
1341 if (!linalgOp.hasPureTensorSemantics())
1343 // The body of the op can still have speculation in its region.
1345}
1346
/// Delegate to the shared structured-op speculatability logic.
Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1350
1351namespace {
1352
1353/// Remove linalg operations that are just copying the values from inputs to
1354/// results. In the memref case, the operation must be copying to and from the
1355/// same value. Requirements are:
1356/// 1) All iterator types are parallel
1357/// 2) The body contains just a yield operation with the yielded values being
1358/// the arguments corresponding to the operands.
1359template <typename OpTy>
1360struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1361 using OpRewritePattern<OpTy>::OpRewritePattern;
1362
1363 LogicalResult matchAndRewrite(OpTy linalgOp,
1364 PatternRewriter &rewriter) const override {
1365 // All indexing maps must be equal. It follows that they are permutations.
1366 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1367 return failure();
1368
1369 // Check that the body of the linalg operation is just a linalg.yield
1370 // operation.
1371 Block &body = linalgOp->getRegion(0).front();
1372 if (!llvm::hasSingleElement(body))
1373 return failure();
1374 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1375 if (!yieldOp)
1376 return failure();
1377
1378 // In the buffer case, we need to check exact buffer equality.
1379 if (linalgOp.hasPureBufferSemantics()) {
1380 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1381 linalgOp.getDpsInputOperand(0)->get() !=
1382 linalgOp.getDpsInitOperand(0)->get()) {
1383 return rewriter.notifyMatchFailure(
1384 linalgOp, "expected single input and output to be the same value");
1385 }
1386
1387 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1388 if (!yieldArg || yieldArg.getOwner() != &body) {
1389 return rewriter.notifyMatchFailure(linalgOp,
1390 "cannot fold fill-like op");
1391 }
1392
1393 rewriter.eraseOp(linalgOp);
1394 return success();
1395 }
1396
1397 if (!linalgOp.hasPureTensorSemantics()) {
1398 return rewriter.notifyMatchFailure(
1399 linalgOp, "mixed semantics is not supported yet");
1400 }
1401
1402 // Get the argument number of the returned values. That is the operand
1403 // number to use for replacing uses of this operation.
1404 SmallVector<Value> returnedArgs;
1405 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1406 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1407 if (!yieldArg || yieldArg.getOwner() != &body)
1408 return failure();
1409 unsigned argumentNumber = yieldArg.getArgNumber();
1410 Value returnedArg = linalgOp->getOperand(argumentNumber);
1411 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1412 // The input can have a different type than the result, e.g. a dynamic
1413 // input dimension can be turned into a static output dimension.
1414 Type returnType = returnedArg.getType();
1415 if (returnType != resultType) {
1416 // Distinguish between sparse conversion or dense tensor casting.
1417 // TODO: unify the two ops?
1420 returnedArg = sparse_tensor::ConvertOp::create(
1421 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1422 else {
1423 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1424 resultType))
1425 return failure();
1426 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1427 resultType, returnedArg);
1428 }
1429 }
1430 returnedArgs.push_back(returnedArg);
1431 }
1432
1433 if (returnedArgs.size() != linalgOp->getNumResults())
1434 return failure();
1435 rewriter.replaceOp(linalgOp, returnedArgs);
1436 return success();
1437 }
1438};
1439
1440} // namespace
1441
/// Canonicalize generic ops that merely forward inputs to results.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1446
/// The only folding performed is absorbing memref.cast ops on operands.
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
1450
1451//===----------------------------------------------------------------------===//
1452// MapOp
1453//===----------------------------------------------------------------------===//
1454
1455static ParseResult parseDstStyleOp(
1457 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1458 nullptr) {
1459 // Parse `ins` and `outs`.
1460 SmallVector<Type, 4> inputTypes, outputTypes;
1461 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1462 /*addOperandSegmentSizes=*/false))
1463 return failure();
1464
1465 // Add result types.
1466 for (Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1469 }
1470
1471 // Parse required attributes.
1472 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1473 return failure();
1474
1475 // Parse optional attributes.
1476 if (parser.parseOptionalAttrDict(result.attributes))
1477 return failure();
1478 return success();
1479}
1480
1481void MapOp::getAsmBlockArgumentNames(Region &region,
1482 OpAsmSetValueNameFn setNameFn) {
1483 for (Value v : getRegionInputArgs())
1484 setNameFn(v, "in");
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v, "init");
1487}
1488
1489void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(), "mapped");
1492}
1493
1494void MapOp::build(
1495 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1496 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1497 ArrayRef<NamedAttribute> attributes) {
1498 build(builder, result, TypeRange{}, inputs, init);
1499 result.addAttributes(attributes);
1500
1501 // Add output types for `RankedTensorType` output arguments.
1502 Type initType = init.getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1505
1506 if (bodyBuild)
1507 buildGenericRegion(builder, result.location, *result.regions.front(),
1508 inputs, /*outputs=*/{init}, bodyBuild);
1509}
1510
1512 const OperationName &payloadOpName,
1513 const NamedAttrList &payloadOpAttrs,
1514 ArrayRef<Value> operands,
1515 bool initFirst = false, bool mapInit = true) {
1516 OpBuilder b(parser.getContext());
1517 Region *body = result.addRegion();
1518 Block &block = body->emplaceBlock();
1519 b.setInsertionPointToStart(&block);
1520 for (auto &operand : operands) {
1521 block.addArgument(
1522 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1523 b.getUnknownLoc());
1524 }
1525 SmallVector<Value> payloadOpOperands;
1526 // If initFirst flag is enabled, we consider init as the first position of
1527 // payload operands.
1528 if (initFirst) {
1529 if (mapInit)
1530 payloadOpOperands.push_back(block.getArguments().back());
1531 for (const auto &arg : block.getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1533 } else {
1534 payloadOpOperands = {block.getArguments().begin(),
1535 block.getArguments().end() - int(!mapInit)};
1536 }
1537
1538 Operation *payloadOp = b.create(
1539 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1540 payloadOpOperands,
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1542 .getElementType()},
1543 payloadOpAttrs);
1544 YieldOp::create(b, result.location, payloadOp->getResults());
1545}
1546
/// Parse linalg.map in either short form (`{ payload-op }` before ins/outs)
/// or long form (explicit region with block arguments).
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  // Optional short form: `{ op-name attr-dict }`.
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the region around the payload op. The init does
    // not feed the payload (last argument is dropped — see
    // addBodyWithPayloadOp with mapInit=false).
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs, ArrayRef(result.operands),
                           /*initFirst=*/false,
                           /*mapInit=*/false);
    else
      result.addRegion();
  } else {
    // Long form: explicit parenthesized block arguments plus region body.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1583
/// Return true when `body` can be printed in the compact `{ payload-op }`
/// form: a single payload op consuming the block arguments in order and
/// yielded directly.
static bool canUseShortForm(Block *body, bool initFirst = false,
                            bool mapInit = true) {
  // `initFirst == true` implies that we want to map the init arg.
  if (initFirst && !mapInit)
    return false;
  // Check if the body can be printed in short form. The following 4 conditions
  // must be satisfied:

  // 1) The body must contain exactly 2 operations: the payload op and a yield.
  if (body->getOperations().size() != 2)
    return false;
  Operation &payload = body->getOperations().front();

  // 2) The payload op must have the same number of operands as the number of
  //    block arguments.
  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
    return false;

  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
  //    must be the first operand of the payload op, otherwise, the operands
  //    must match the block arguments in order.
  if (initFirst) {
    // check init
    if (payload.getOperands().back() != body->getArgument(0))
      return false;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return false;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(),
                   body->getArguments().drop_back(int(!mapInit)))) {
      if (bbArg != operand)
        return false;
    }
  }

  // 4) The `yield` operand must be the result of the payload op.
  auto yieldOp = cast<YieldOp>(body->getTerminator());
  return yieldOp.getNumOperands() == 1 &&
         yieldOp.getOperand(0).getDefiningOp() &&
         yieldOp.getOperand(0).getDefiningOp() == &payload;
}
1631
/// Print a payload op in short form: `{ op-name attr-dict }`, eliding a
/// fastmath attribute when it holds the default value (`none`).
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      // `attrToElide` owns the string storage; `elidedAttrs` keeps a view.
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}
1648
/// Print linalg.map, preferring the compact `{ payload-op }` form when the
/// body qualifies; otherwise print the full region with its arguments.
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  bool useShortForm =
      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
  if (useShortForm) {
    printShortForm(p, &mapper->getOperations().front());
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!useShortForm) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1673
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1677
1678 // Checks if the number of `inputs` + `init` match the arity of the `mapper`
1679 // region.
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() << "expects number of operands to match the arity of "
1682 "mapper, but got: "
1683 << getInputs().size() + 1 << " and "
1684 << blockArgs.size();
1685
1686 // The parameters of mapper should all match the element type of inputs.
1687 for (const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() << "expected element type of input " << inputElemType
1693 << " to match bbArg type " << bbArgType;
1694 }
1695 }
1696
1697 // The shape of each input must match the shape of the output.
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType : TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() << "expected shape of input (" << inputElemShape
1703 << ") to match shape of output (" << outputShape
1704 << ")";
1705 }
1706 }
1707
1708 return success();
1709}
1710
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1714}
1715
1716ArrayAttr MapOp::getIndexingMaps() {
1717 Builder builder(getContext());
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1720 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1721 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1722}
1723
/// Memory effects are fully determined by the shared structured-op helper.
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1729
/// Delegate to the shared structured-op speculatability logic.
Speculation::Speculatability MapOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1733
1734//===----------------------------------------------------------------------===//
1735// ReduceOp
1736//===----------------------------------------------------------------------===//
1737
1738void ReduceOp::getAsmBlockArgumentNames(Region &region,
1739 OpAsmSetValueNameFn setNameFn) {
1740 for (Value v : getRegionInputArgs())
1741 setNameFn(v, "in");
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v, "init");
1744}
1745
1746void ReduceOp::getAsmResultNames(
1747 function_ref<void(Value, StringRef)> setNameFn) {
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(), "reduced");
1750}
1751
1752void ReduceOp::build(
1753 OpBuilder &builder, OperationState &result, ValueRange inputs,
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1755 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1756 ArrayRef<NamedAttribute> attributes) {
1757 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1758 result.addAttributes(attributes);
1759
1760 // Add output types for `RankedTensorType` output arguments.
1761 for (Value init : inits) {
1762 Type initType = init.getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1765 }
1766
1767 if (bodyBuild)
1768 buildGenericRegion(builder, result.location, *result.regions.front(),
1769 inputs, inits, bodyBuild);
1770}
1771
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1773 int64_t inputRank =
1774 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1780}
1781
1782ArrayAttr ReduceOp::getIndexingMaps() {
1783 int64_t inputRank =
1784 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1785 SmallVector<AffineMap> affineMaps(
1786 getNumDpsInputs(),
1788 AffineMap resultMap =
1790 .dropResults(getDimensions());
1791 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1792 affineMaps.push_back(resultMap);
1793 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1794}
1795
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1798 &effects) {
1799 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1800}
1801
1802Speculation::Speculatability ReduceOp::getSpeculatability() {
1803 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1804}
1805
1806static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1807 NamedAttrList &attributes,
1808 StringRef attributeName) {
1809 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1810 return failure();
1811
1812 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1813 return success();
1814}
1815
/// Parses a reduce op. Supports the short form with an inlined payload op in
/// braces (`{ op-name attr-dict }` before the ins/outs) and the long form
/// with explicit region block arguments and a region body.
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short-form header: `{ payload-op attr-dict }`.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Common destination-style parts (ins/outs) plus the mandatory
  // `dimensions` array attribute.
  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the region from the payload op; inits come
    // first in the payload op's operand order.
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: explicit region arguments (typed, possibly with attrs)
    // followed by the region body.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1853
1854static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1855 ArrayRef<int64_t> attributeValue) {
1856 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1857}
1858
/// Prints a reduce op, preferring the short `{ payload-op }` form when the
/// body is a single recognizable payload op (inits first); otherwise prints
/// the explicit region.
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
  if (useShortForm) {
    printShortForm(p, &mapper->getOperations().front());
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` was printed explicitly above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!useShortForm) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1882
/// Verifies a reduce op: equal input/init operand counts, uniform shapes
/// within each group, in-range reduction dimensions, init shape equal to the
/// input shape with reduced dims removed, and block arguments matching the
/// operand element types.
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // The ReduceOp uses `SameVariadicOperandSize`, which requires equal numbers
  // of inputs and inits. Detect a mismatch early: when they differ, the
  // ODS-generated getInputs()/getInits() accessors compute each group's size
  // via floordiv of the total operand count, producing incorrect slices that
  // would cause out-of-bounds accesses below.
  if (getInputs().size() != static_cast<size_t>(getNumDpsInputs()))
    return emitOpError()
           << "expected equal number of inputs and outputs (required by "
              "SameVariadicOperandSize), got "
           << getNumDpsInputs() << " input(s) and " << getNumDpsInits()
           << " output(s)";

  if (getInputs().empty())
    return emitOpError() << "expected at least one input";

  // All inputs must share the shape of input 0 (they are iterated jointly).
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  // Likewise all inits must share the shape of init 0.
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Validate the reduction dimensions and collect them for the shape check.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  // The surviving dims must match the init rank, then the init shape.
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1983
1984//===----------------------------------------------------------------------===//
1985// TransposeOp
1986//===----------------------------------------------------------------------===//
1987
1988static void buildIdentityRegion(OpBuilder &builder, Location loc,
1989 Region &region, ValueRange inputs,
1990 ValueRange outputs) {
1991 buildGenericRegion(builder, loc, region, inputs, outputs,
1992 [](OpBuilder &b, Location loc, ValueRange args) {
1993 if (!args.empty())
1994 linalg::YieldOp::create(b, loc, args[0]);
1995 });
1996}
1997
1998void TransposeOp::build(::mlir::OpBuilder &builder,
1999 ::mlir::OperationState &result, Value input, Value init,
2000 DenseI64ArrayAttr permutation,
2001 ArrayRef<NamedAttribute> attributes) {
2002 result.addOperands(input);
2003 result.addOperands(init);
2004 result.addAttribute(getPermutationAttrName(result.name), permutation);
2005 result.addAttributes(attributes);
2006
2007 // Add output types for `RankedTensorType` output arguments.
2008 Type initType = init.getType();
2009 if (llvm::isa<RankedTensorType>(initType))
2010 result.addTypes(initType);
2011
2012 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2013 init);
2014}
2015
2016void TransposeOp::build(::mlir::OpBuilder &builder,
2017 ::mlir::OperationState &result, Value input, Value init,
2018 ArrayRef<int64_t> permutation,
2019 ArrayRef<NamedAttribute> attributes) {
2020 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2021 attributes);
2022}
2023
2024ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2026 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2027 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2028 })))
2029 return failure();
2030
2031 OpBuilder builder(parser.getContext());
2032 buildIdentityRegion(builder, result.location, *result.addRegion(),
2033 /*inputs=*/result.operands,
2034 /*outputs=*/{});
2035 return success();
2036}
2037
2038void TransposeOp::getAsmResultNames(
2039 function_ref<void(Value, StringRef)> setNameFn) {
2040 if (!getResults().empty())
2041 setNameFn(getResults().front(), "transposed");
2042}
2043
/// Prints the destination-style parts followed by the `permutation` array;
/// `permutation` is elided from the trailing dict since it is printed
/// explicitly.
void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
2049
/// Verifies a transpose op: the permutation is valid, input and init ranks
/// match, the permutation size equals the rank, and each init dim equals the
/// permuted input dim.
LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
                              "init")))
    return failure();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // dim i of the init must equal dim permutation[i] of the input.
  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }

  return success();
}
2085
2086SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2087 int64_t rank = getInit().getType().getRank();
2088 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2089}
2090
2091ArrayAttr TransposeOp::getIndexingMaps() {
2092 Builder builder(getContext());
2093 int64_t rank = getInit().getType().getRank();
2094 return builder.getAffineMapArrayAttr(
2096 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2097 builder.getMultiDimIdentityMap(rank)});
2098}
2099
2100void TransposeOp::getEffects(
2101 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2102 &effects) {
2103 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2104}
2105
2106Speculation::Speculatability TransposeOp::getSpeculatability() {
2107 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2108}
2109
2110LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2111 SmallVectorImpl<OpFoldResult> &result) {
2112 // Only the tensor type is supported.
2113 if (!isa<TensorType>(getInput().getType()))
2114 return failure();
2115
2116 // Single dimension transpose.
2117 if (getPermutation().empty()) {
2118 result.push_back(getInput());
2119 return success();
2120 }
2121 // Identity permutation.
2122 if (isIdentityPermutation(getPermutation())) {
2123 result.push_back(getInput());
2124 return success();
2125 }
2126
2127 return failure();
2128}
2129
2130/// Fold transpose with transpose.
2131struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2132 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2133
2134 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2135 PatternRewriter &rewriter) const override {
2136 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2137 if (!defTransposeOp)
2138 return failure();
2139 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2140 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2141 SmallVector<int64_t> foldedPerms;
2142 foldedPerms.reserve(perms.size());
2143 for (int64_t perm : perms)
2144 foldedPerms.push_back(defPerms[perm]);
2145
2146 rewriter.replaceOpWithNewOp<TransposeOp>(
2147 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2148 foldedPerms);
2149 return success();
2150 }
2151};
2152
2153/// This pattern canonicalize transpose by swapping the order of
2154/// broadcast and transpose:
2155/// transpose(broadcast(input)) -> broadcast(transpose(input))
2156struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2157 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2158
2159 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2160 PatternRewriter &rewriter) const override {
2161 Value input = transposeOp.getInput();
2162 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2163 if (!input.hasOneUse() || !broadcastOp)
2164 return failure();
2165
2166 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2167 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2168
2169 // Get new perms and new dimensions.
2170 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2172 SmallVector<int64_t> resultDimensions;
2173 unsigned dimensionSize = dimensions.size();
2174 for (unsigned i = 0; i < dimensionSize; ++i)
2175 resultDimensions.push_back(invertPerm[dimensions[i]]);
2176
2177 // Create transpose result.
2178 Value broadcastInput = broadcastOp.getInput();
2179 Location loc = transposeOp.getLoc();
2180 MLIRContext *ctx = transposeOp.getContext();
2182 auto broadcastInputTy =
2183 mlir::cast<RankedTensorType>(broadcastInput.getType());
2184 unsigned inputRank = broadcastInputTy.getRank();
2185 for (unsigned i = 0; i < inputRank; ++i) {
2186 if (broadcastInputTy.isDynamicDim(i)) {
2187 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2188 ->getResult(0));
2189 } else {
2190 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2191 broadcastInputTy.getDimSize(i)));
2192 }
2193 }
2194 SmallVector<OpFoldResult> transposeResultShapes =
2195 applyPermutation(dims, resultPerms);
2196 Value transposeInit = tensor::EmptyOp::create(
2197 rewriter, transposeOp.getLoc(), transposeResultShapes,
2198 broadcastInputTy.getElementType());
2199
2200 // Create broadcast(transpose(input)).
2201 Value transposeResult =
2202 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2203 transposeInit, resultPerms)
2204 ->getResult(0);
2205 rewriter.replaceOpWithNewOp<BroadcastOp>(
2206 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2207 return success();
2208 }
2209};
2210
2211void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2212 MLIRContext *context) {
2213 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2214}
2215
2216//===----------------------------------------------------------------------===//
2217// BroadcastOp
2218//===----------------------------------------------------------------------===//
2219
2220void BroadcastOp::build(::mlir::OpBuilder &builder,
2221 ::mlir::OperationState &result, Value input, Value init,
2222 DenseI64ArrayAttr dimensions,
2223 ArrayRef<NamedAttribute> attributes) {
2224 result.addOperands(input);
2225 result.addOperands(init);
2226 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2227 result.addAttributes(attributes);
2228
2229 // Add output types for `RankedTensorType` output arguments.
2230 Type initType = init.getType();
2231 if (llvm::isa<RankedTensorType>(initType))
2232 result.addTypes(initType);
2233
2234 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2235 init);
2236}
2237
2238void BroadcastOp::build(::mlir::OpBuilder &builder,
2239 ::mlir::OperationState &result, Value input, Value init,
2240 ArrayRef<int64_t> dimensions,
2241 ArrayRef<NamedAttribute> attributes) {
2242 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2243 attributes);
2244}
2245
2246ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2248 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2249 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2250 })))
2251 return failure();
2252
2253 OpBuilder builder(parser.getContext());
2254 buildIdentityRegion(builder, result.location, *result.addRegion(),
2255 /*inputs=*/result.operands,
2256 /*outputs=*/{});
2257 return success();
2258}
2259
2260void BroadcastOp::getAsmResultNames(
2261 function_ref<void(Value, StringRef)> setNameFn) {
2262 if (!getResults().empty())
2263 setNameFn(getResults().front(), "broadcasted");
2264}
2265
/// Prints the destination-style parts followed by the `dimensions` array;
/// `dimensions` is elided from the trailing dict since it is printed
/// explicitly.
void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
2271
/// Verifies a broadcast op: input rank plus the number of added dimensions
/// equals the init rank, each added dimension index is in range, and every
/// init dim that maps back to the input matches the input dim.
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  // Each broadcast dimension must index a valid init dimension.
  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimensions is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
2317
2318SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2319 int64_t rank = getInit().getType().getRank();
2320 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2321}
2322
2323ArrayAttr BroadcastOp::getIndexingMaps() {
2324 Builder builder(getContext());
2325 int64_t rank = getInit().getType().getRank();
2326 return builder.getAffineMapArrayAttr(
2327 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2328 builder.getMultiDimIdentityMap(rank)});
2329}
2330
2331void BroadcastOp::getEffects(
2332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2333 &effects) {
2334 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2335}
2336
2337Speculation::Speculatability BroadcastOp::getSpeculatability() {
2338 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2339}
2340
2341/// Fold back-to-back broadcasts together.
/// Fold back-to-back broadcasts together.
struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
    if (!defBroadcastOp)
      return failure();
    ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    // Start from the consumer's dims and add the producer's dims remapped
    // into the final init's dimension space.
    SmallVector<int64_t> foldedDims(dimensions);
    Value init = broadcastOp.getInit();
    int64_t initRank = cast<ShapedType>(init.getType()).getRank();
    // Mapping from input dims to init dims.
    SmallVector<int64_t> dimMap;
    for (auto dim : llvm::seq<int64_t>(0, initRank)) {
      if (!llvm::is_contained(dimensions, dim))
        dimMap.push_back(dim);
    }
    for (auto dim : defDimensions)
      foldedDims.push_back(dimMap[dim]);

    // Keep the merged dimension list in ascending order.
    llvm::sort(foldedDims);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
    return success();
  }
};
2370
2371void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2372 MLIRContext *context) {
2373 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2374}
2375
2376//===----------------------------------------------------------------------===//
2377// YieldOp
2378//===----------------------------------------------------------------------===//
2379
/// Prints ` operands attr-dict : operand-types`, omitting the operand list
/// and the type list when there are no operands.
void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}
2387
/// Parses `operands attr-dict (: types)?`; the type list is only parsed when
/// operands are present (short-circuit keeps `types` empty otherwise), so a
/// bare `linalg.yield` is accepted.
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}
2397
2398// Check the operand number and types must match the element types of the
2399// LinalgOp interface's shaped operands.
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // One yielded value per init/out operand of the enclosing op.
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    // Shaped inits are compared via their element type; other init types are
    // compared directly.
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
2422
2423LogicalResult linalg::YieldOp::verify() {
2424 auto *parentOp = (*this)->getParentOp();
2425 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2426 return emitOpError("expected single non-empty parent region");
2427
2428 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2429 return verifyYield(*this, linalgOp);
2430
2431 return emitOpError("expected parent op with LinalgOp interface");
2432}
2433
2434//===----------------------------------------------------------------------===//
2435// IndexOp
2436//===----------------------------------------------------------------------===//
2437
2438LogicalResult IndexOp::verify() {
2439 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2440 if (!linalgOp)
2441 return emitOpError("expected parent op with LinalgOp interface");
2442 if (linalgOp.getNumLoops() <= getDim())
2443 return emitOpError("expected dim (")
2444 << getDim() << ") to be lower than the number of loops ("
2445 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2446 return success();
2447}
2448
/// Folds `linalg.index` over a unit-extent loop to the constant index 0.
OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
  // Bail out if `linalg.index` does not have a proper parent yet at this
  // point, e.g., when calling `createOrFold` during IR construction in
  // `genericOp::build`.
  if (!linalgOp)
    return OpFoldResult{};

  // Index of unit dims is always 0.
  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
  uint64_t dim = getDim();
  assert(dim < loopBounds.size() && "Dim is out of bounds");
  if (loopBounds[dim] == 1)
    return IntegerAttr::get(IndexType::get(getContext()), 0);

  return OpFoldResult{};
}
2466
2467/////// Operations corresponding to library calls defined with Tablegen ////////
2468
2469#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2470
2471#define GET_OP_CLASSES
2472#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2473
2474#define GET_OP_CLASSES
2475#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2476#define GET_OP_CLASSES
2477#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2478
2479AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2480 unsigned rank,
2481 MLIRContext *context) {
2482 if (maybeMap)
2483 return *maybeMap;
2484 if (rank == 0)
2485 return AffineMap::get(context);
2486 return AffineMap::getMultiDimIdentityMap(rank, context);
2487}
2488
2490mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2491 MLIRContext *context) {
2493 res.reserve(num);
2494 for (unsigned i = 0; i < num; ++i)
2495 res.push_back(getAffineDimExpr(startIdx++, context));
2496 return res;
2497}
2498
2501 auto rangeA = llvm::make_range(a.begin(), a.end());
2502 auto rangeB = llvm::make_range(b.begin(), b.end());
2503 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2504 return llvm::to_vector<4>(concatRanges);
2505}
2506
2507static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2508 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2509 ss << "view";
2510 for (auto size : memref.getShape())
2511 if (size < 0)
2512 ss << "sx";
2513 else
2514 ss << size << "x";
2515 if (failed(appendMangledType(ss, memref.getElementType())))
2516 return failure();
2517 if (auto as = memref.getMemorySpace()) {
2518 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2519 ss << "as" << attr.getInt();
2520 else
2521 return failure();
2522 }
2523 return success();
2524 }
2525 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2526 ss << "vector";
2527 llvm::interleave(
2528 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2529 if (failed(appendMangledType(ss, vec.getElementType())))
2530 return failure();
2531 return success();
2532 }
2534 ss << t;
2535 return success();
2536 }
2537 return failure();
2538}
2539
2541 assert(isa<LinalgOp>(op));
2542 std::string name(op->getName().getStringRef().str());
2543 std::string fun = "";
2544 for (NamedAttribute kv : op->getAttrs()) {
2545 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2546 fun = stringifyEnum(ufa.getValue()).str() + "_";
2547 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2548 fun = stringifyEnum(bfa.getValue()).str() + "_";
2549 }
2550 }
2551 name.reserve(128);
2552 llvm::replace(name, '.', '_');
2553 llvm::raw_string_ostream ss(name);
2554 ss << "_" << fun;
2555 for (Type t : op->getOperandTypes()) {
2556 if (failed(appendMangledType(ss, t)))
2557 return std::string();
2558 ss << "_";
2559 }
2560 name.pop_back();
2561 return name;
2562}
2563
2564//===----------------------------------------------------------------------===//
2565// Canonicalizers and Folders.
2566//===----------------------------------------------------------------------===//
2567
2568namespace {
2569struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2571
2572 LogicalResult matchAndRewrite(LinalgOp op,
2573 PatternRewriter &rewriter) const override {
2574 for (OpOperand &opOperand : op->getOpOperands()) {
2575 // Linalg "inputs" may be either tensor or memref type.
2576 // tensor<0xelt_type> is a convention that may not always mean
2577 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2578 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2579 if (!mt)
2580 continue;
2581 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2582 rewriter.eraseOp(op);
2583 return success();
2584 }
2585 }
2586 return failure();
2587 }
2588};
2589
2590/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2591/// result that is more static than the linalg op.
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // Cast can be in conditionally reachable region, if which case folding will
    // generate invalid code. Only conservatively fold ops in same block for
    // now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    // The cloned op is built where the linalg op sits so all its operands
    // still dominate it.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    // Clone the linalg op with the affected result type tightened to the
    // cast's (more static) type.
    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = tensor::CastOp::create(
        rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));

    // Other users of the old op keep seeing the original (cast-back) type;
    // the matched cast itself is replaced by the new, more static result.
    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
2648
/// For each operand in `operands`, maps the affine dim expressions of its
/// indexing map to the known static size of the corresponding dimension,
/// recording the results in `affineExprToSize`. Scalar operands are skipped.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
    // shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    // NOTE(review): dims that are dynamic in `sourceType` are skipped even
    // when `sourceShape` was taken from a static cast source above -- confirm
    // that is intended.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
2685
/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. The (possibly updated) operand is appended
/// to `newOperands`; for init operands the corresponding result type is
/// appended to `resultTypes`. If the operand's type changes, `changeNeeded`
/// is set to true (it is never reset to false here).
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Always append the original operand first; it is overwritten in place
  // below if a cast to a more static type is needed.
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    // Already fully static: nothing to infer for this init.
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      // Either already static, or no known size for this dim expression.
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2738
2739/// Static shapes for the operands can be inferred if any one of the operands
2740/// have a static shape. This can be done by referring to the affine dim
2741/// expressions for the operand.
2742struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2743 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2744
2745 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2746 PatternRewriter &rewriter) const override {
2747 if (!linalgOp.hasPureTensorSemantics())
2748 return failure();
2749
2750 // Maps must be projected permutations.
2751 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2752 return !map.isProjectedPermutation();
2753 }))
2754 return failure();
2755
2756 // Maps affine dim expressions to the static size of that dimension.
2757 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2758 Location loc = linalgOp.getLoc();
2759
2760 // For each of the affine dim expression, check if the size is known. If
2761 // known add that in the map.
2762 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2763
2764 SmallVector<Value> newOperands;
2765 SmallVector<Type> resultTypes;
2766
2767 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2768 // change in their types.
2769 bool changeNeeded = false;
2770 newOperands.reserve(linalgOp->getNumOperands());
2771 resultTypes.reserve(linalgOp.getNumDpsInits());
2772
2773 // Iterate over all the operands and update the static sizes.
2774 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2775 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2776 affineExprToSize, linalgOp, newOperands,
2777 resultTypes, changeNeeded);
2778 }
2779
2780 // If the generic op has all the required static information, no
2781 // canonicalization needed.
2782 if (!changeNeeded)
2783 return failure();
2784
2785 // Clone op.
2786 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2787 SmallVector<Value> replacements;
2788 replacements.reserve(newOp->getNumResults());
2789 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2790 Value newResult = std::get<1>(it);
2791 Value oldResult = std::get<0>(it);
2792 Type newType = newResult.getType();
2793 Type oldType = oldResult.getType();
2794 replacements.push_back(
2795 (newType != oldType)
2796 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2797 : newResult);
2798 }
2799 rewriter.replaceOp(linalgOp, replacements);
2800 return success();
2801 }
2802};
2803
2804} // namespace
2805
2806// All named ops canonicalizers and folders are auto-generated in the
2807// .cpp.inc.
2808
2809//===----------------------------------------------------------------------===//
2810// SoftmaxOp
2811//===----------------------------------------------------------------------===//
2812
2813LogicalResult SoftmaxOp::verify() {
2814 ShapedType inputType = getInputOperandType();
2815 ShapedType outputType = getOutputOperandType();
2816
2817 ArrayRef<int64_t> inputShape = inputType.getShape();
2818 ArrayRef<int64_t> outputShape = outputType.getShape();
2819 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2820 return emitOpError("incompatible output shape");
2821
2822 int64_t inputRank = getInputOperandRank();
2823 int64_t dimension = getDimension();
2824 if ((dimension < 0) || (dimension >= inputRank))
2825 return emitOpError("incorrect dimension specified");
2826
2827 return success();
2828}
2829
2830SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2831 int64_t operandRank = getInputOperandRank();
2832 SmallVector<Range> loopBounds(operandRank);
2833 Location loc = getLoc();
2834 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2835 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2836 Value source = getInput();
2837 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2838 loopBounds[dim].offset = zero;
2839 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2840 loopBounds[dim].stride = one;
2841 }
2842 return loopBounds;
2843}
2844
2845SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2846 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2847 utils::IteratorType::parallel);
2848 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2849 return iteratorTypes;
2850}
2851
2852FailureOr<TilingResult>
2853SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2854 ArrayRef<OpFoldResult> offsets,
2855 ArrayRef<OpFoldResult> sizes) {
2856 int64_t rank = getInputOperandRank();
2857 auto oneAttr = builder.getI64IntegerAttr(1);
2858 SmallVector<OpFoldResult> strides(rank, oneAttr);
2859 SmallVector<Value> tiledOperands;
2860 Operation *inputSlice =
2861 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2862 if (!inputSlice) {
2863 return emitOpError("failed to compute input slice");
2864 }
2865 tiledOperands.emplace_back(inputSlice->getResult(0));
2866 Operation *outputSlice =
2867 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2868 if (!outputSlice) {
2869 return emitOpError("failed to compute output slice");
2870 }
2871 tiledOperands.emplace_back(outputSlice->getResult(0));
2872
2873 SmallVector<Type, 4> resultTypes;
2874 if (hasPureTensorSemantics())
2875 resultTypes.push_back(tiledOperands[1].getType());
2876 Operation *tiledOp =
2877 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2878
2879 return TilingResult{
2880 {tiledOp},
2881 SmallVector<Value>(tiledOp->getResults()),
2882 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2883}
2884
2885LogicalResult SoftmaxOp::getResultTilePosition(
2886 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2887 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2888 SmallVector<OpFoldResult> &resultSizes) {
2889 if (resultNumber == 0) {
2890 resultOffsets.assign(offsets.begin(), offsets.end());
2891 resultSizes.assign(sizes.begin(), sizes.end());
2892 return success();
2893 }
2894 return failure();
2895}
2896
// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Fold `memref.cast` producers into this op's operands when the cast only
  // refines the type (see memref::foldMemRefCast).
  return memref::foldMemRefCast(*this);
}
2901
2902LogicalResult
2903SoftmaxOp::reifyResultShapes(OpBuilder &b,
2904 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2905 SmallVector<OpFoldResult> shapes;
2906 Location loc = getOperation()->getLoc();
2907 IRRewriter rewriter(b);
2908 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2909 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2910 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2911 if (!outputShapedType.isDynamicDim(dim)) {
2912 // Static dim: Return IntegerAttr.
2913 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2914 } else {
2915 // Dynamic dim: Return Value.
2916 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2917 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2918 }
2919 }
2920 reifiedReturnShapes.emplace_back(std::move(shapes));
2921 return success();
2922}
2923
2924void SoftmaxOp::getEffects(
2925 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2926 &effects) {
2927 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2928 if (!llvm::isa<MemRefType>(operand.getType()))
2929 continue;
2930 effects.emplace_back(MemoryEffects::Read::get(),
2931 &getOperation()->getOpOperand(index), /*stage=*/0,
2932 /*effectOnFullRegion=*/true,
2934 }
2935
2936 for (OpOperand &operand : getDpsInitsMutable()) {
2937 if (!llvm::isa<MemRefType>(operand.get().getType()))
2938 continue;
2939 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2940 /*effectOnFullRegion=*/true,
2942 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2943 /*effectOnFullRegion=*/true,
2945 }
2946}
2947
2948// Helper functions for softmax decomposition.
2949// @{
2950
2951// Helper function to produce the iterator types (reduction or parallel) and
2952// affine maps for the iterators used in the decomposition of softmax.
2953// This method creates:
2954// If allParallel == true:
2955// - iterator type: {parallel, ..., parallel}
2956// - affine maps:
2957// -- identity with inputRank dimensions.
2958// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2959// where N == inputRank.
2960//
2961// If allParallel == false:
2962// - iterator type at dim(i) == parallel for i != \p dim and
2963// dim(dim) == reduction.
2964// - affine map:
2965// -- identity with inputRank dimensions.
2966// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2967// where N == inputRank.
2968static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2970 int64_t dim, bool allParallel = false) {
2971 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2972 utils::IteratorType::parallel);
2973 if (!allParallel)
2974 iteratorTypes[dim] = utils::IteratorType::reduction;
2975 MLIRContext *ctxt = builder.getContext();
2976 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2977 SmallVector<AffineExpr, 2> affineExprs;
2978 for (int i = 0; i < inputRank; i++) {
2979 if (i != dim)
2980 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2981 }
2982 auto reductionMap =
2983 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2984 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2985 return std::make_tuple(iteratorTypes, indexingMaps);
2986}
2987
2988// Helper function to produce a linalg.generic that computes a reduction on
2989// dimension \p dim with the operation type \p T.
2990template <typename T>
2991static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2992 int64_t dim) {
2993 auto inputType = cast<ShapedType>(input.getType());
2994 ArrayRef<int64_t> inputShape = inputType.getShape();
2995 int64_t inputRank = inputShape.size();
2996 auto [iteratorTypes, indexingMaps] =
2997 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2998 assert(indexingMaps.size() == 2 &&
2999 "We should have two maps: 1 for the input, 1 for the output");
3000 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3001
3002 auto genericOp = linalg::GenericOp::create(
3003 builder, loc, output.getType(), input, output, indexingMaps,
3004 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
3005 Value result = T::create(b, loc, args[0], args[1]);
3006 linalg::YieldOp::create(b, loc, result);
3007 });
3008 return genericOp.getResult(0);
3009}
3010
3011/// Produce a linalg generic that computes the second step of the softmax
3012/// decomposition: res = exp(input - max), where \p max is the max of \p input
3013/// on dimension \p dim.
3014static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3015 Value max, Value output, int64_t dim) {
3016 auto inputType = cast<ShapedType>(input.getType());
3017 ArrayRef<int64_t> inputShape = inputType.getShape();
3018 int64_t inputRank = inputShape.size();
3019 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3020 builder, inputRank, dim, /*allParallel=*/true);
3021 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3022 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3023 // Add the affine map for the output argument.
3024 indexingMaps.push_back(indexingMaps[0]);
3025 auto genericOp = linalg::GenericOp::create(
3026 builder, loc, input.getType(), ValueRange{input, max}, output,
3027 indexingMaps, iteratorTypes,
3028 [&](OpBuilder &b, Location loc, ValueRange args) {
3029 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3030 Value result = math::ExpOp::create(b, loc, diff);
3031 linalg::YieldOp::create(b, loc, result);
3032 });
3033 return genericOp.getResult(0);
3034}
3035
3036/// Produce a linalg generic that computes the final step of the softmax
3037/// decomposition.
3038/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3039/// yield n / d
3040/// }
3041static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3042 Value denominator, Value output, int64_t dim) {
3043 auto inputType = cast<ShapedType>(numerator.getType());
3044 ArrayRef<int64_t> inputShape = inputType.getShape();
3045 int64_t inputRank = inputShape.size();
3046 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3047 builder, inputRank, dim, /*allParallel=*/true);
3048 assert(indexingMaps.size() == 2 &&
3049 "We should have one map for each input (2)");
3050 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3051 // Add the affine map for the output tensor.
3052 indexingMaps.push_back(indexingMaps[0]);
3053 auto genericOp = linalg::GenericOp::create(
3054 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3055 output, indexingMaps, iteratorTypes,
3056 [&](OpBuilder &b, Location loc, ValueRange args) {
3057 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3058 linalg::YieldOp::create(b, loc, result);
3059 });
3060 return genericOp.getResult(0);
3061}
3062// @} End helper functions for softmax decomposition.
3063
3064/// Given an N-dimensional tensor x, this method converts
3065/// softmax(x) to the following sequence of operations:
3066///
3067/// 1. Compute the max of x along dimension d. This results
3068/// in a N-1 dimensional tensor m.
3069/// m = max(x, dim = d)
3070///
3071/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3072/// a N dimensional tensor z.
3073/// z = exp(x - m)
3074///
3075/// 3. Compute the sum of z along dimension d. This results in
3076/// a N-1 dimensional tensor l.
3077/// l = sum(z, dim = d)
3078///
3079/// 4. Divide z and l. This gives the N-dimensional softmax.
3080/// softmax = z / l
3081///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  // Build the decomposition right before this op.
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  // `dims` now describes the N-1 dimensional shape of the reductions
  // (steps 1 and 3): the reduction dimension is dropped.
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  // `outputReduce` is an empty N-1 dim tensor reused as init for both the
  // max (step 1) and the sum (step 3) reductions.
  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
3120
3121//===----------------------------------------------------------------------===//
3122// WinogradFilterTransformOp
3123//===----------------------------------------------------------------------===//
3124
3125LogicalResult WinogradFilterTransformOp::verify() {
3126 auto filterType = cast<ShapedType>(getFilter().getType());
3127 ArrayRef<int64_t> filterShape = filterType.getShape();
3128 int64_t filterH = filterShape[getFilterHDim()];
3129 int64_t filterW = filterShape[getFilterWDim()];
3130 WinogradConv2DFmr fmr = getFmr();
3131 int64_t m, r;
3132 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3133
3134 if (filterH != r && filterH != 1)
3135 return emitOpError("expect filter height either equals to r or 1");
3136 if (filterW != r && filterW != 1)
3137 return emitOpError("expect filter width either equals to r or 1");
3138 if (filterH == 1 && filterW == 1)
3139 return emitOpError("expect either filter height or width equals to r");
3140
3141 SmallVector<int64_t> expectedOutputShape;
3142 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3143 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3144 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3145 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3146
3147 auto outputType = cast<ShapedType>(getOutput().getType());
3148 ArrayRef<int64_t> outputShape = outputType.getShape();
3149 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3150 return emitOpError("the output shape is not expected");
3151 }
3152 return success();
3153}
3154
3155SmallVector<Range>
3156WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3157 Location loc = getLoc();
3158 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3159 IntegerAttr oneAttr = builder.getIndexAttr(1);
3160 Value filter = getFilter();
3161 int64_t filterRank = getFilterOperandRank();
3162 SmallVector<Range> loopBounds(filterRank);
3163 for (unsigned dim = 0; dim < filterRank; ++dim) {
3164 loopBounds[dim].offset = zeroAttr;
3165 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3166 loopBounds[dim].stride = oneAttr;
3167 }
3168 return loopBounds;
3169}
3170
3171SmallVector<utils::IteratorType>
3172WinogradFilterTransformOp::getLoopIteratorTypes() {
3173 int64_t filterRank = getFilterOperandRank();
3174 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3175 utils::IteratorType::parallel);
3176 return iteratorTypes;
3177}
3178
3179LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3180 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3181 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3182 SmallVector<OpFoldResult> &resultSizes) {
3183 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3184 ShapedType filterType = getFilterOperandType();
3185 ArrayRef<int64_t> filterShape = filterType.getShape();
3186 int64_t filterH = filterShape[getFilterHDim()];
3187 int64_t filterW = filterShape[getFilterWDim()];
3188 WinogradConv2DFmr fmr = getFmr();
3189 int64_t m, r;
3190 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3191 int64_t alpha = m + r - 1;
3192 int64_t alphaH = filterH != 1 ? alpha : 1;
3193 int64_t alphaW = filterW != 1 ? alpha : 1;
3194 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3195 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3196
3197 resultOffsets.append(
3198 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3199 resultSizes.append(
3200 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3201
3202 return success();
3203}
3204
/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // The filter is tiled along F and C only; KH and KW are taken whole.
  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = tensor::ExtractSliceOp::create(
      builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  // Derive the tile of the init/output from the same offsets and sizes.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the result type is the output slice type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
3257
3258//===----------------------------------------------------------------------===//
3259// WinogradInputTransformOp
3260//===----------------------------------------------------------------------===//
3261
LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  int64_t tileSize = m + r - 1;

  // A unit alpha dim in the output marks a spatial direction that is not
  // transformed.
  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  // Expected output layout: (alphaH, alphaW, tileH, tileW, N, C).
  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    // Tile count along H: each tile yields m output rows and consumes
    // r - 1 extra overlapping input rows.
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  // Batch and channel dims carry over from the input unchanged.
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3302
3303SmallVector<Range>
3304WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3305 Location loc = getLoc();
3306 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3307 IntegerAttr oneAttr = builder.getIndexAttr(1);
3308 Value output = getOutput();
3309 int64_t outputRank = getOutputOperandRank();
3310 SmallVector<Range> loopBounds(outputRank);
3311 for (unsigned dim = 0; dim < outputRank; ++dim) {
3312 loopBounds[dim].offset = zeroAttr;
3313 // alphaH, alphaW, tileH, tileW, N, C
3314 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3315 loopBounds[dim].stride = oneAttr;
3316 }
3317 return loopBounds;
3318}
3319
3320SmallVector<utils::IteratorType>
3321WinogradInputTransformOp::getLoopIteratorTypes() {
3322 int64_t outputRank = getOutputOperandRank();
3323 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3324 utils::IteratorType::parallel);
3325 return iteratorTypes;
3326}
3327
3328LogicalResult WinogradInputTransformOp::getResultTilePosition(
3329 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3330 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3331 SmallVector<OpFoldResult> &resultSizes) {
3332 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3333 ShapedType outputType = getOutputOperandType();
3334 ArrayRef<int64_t> outputShape = outputType.getShape();
3335 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3336 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3337
3338 WinogradConv2DFmr fmr = getFmr();
3339 int64_t m, r;
3340 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3341 int64_t alpha = m + r - 1;
3342 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3343 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3344
3345 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3346 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3347
3348 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3349 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3350 offsets[getOutputCDim()]});
3351 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3352 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3353 sizes[getOutputCDim()]});
3354
3355 return success();
3356}
3357
/// Implement tiling for winograd_input_transform.
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
/// C). Users can specify the tile sizes of tileH, tileW, N, and C. `offsets`
/// are the values for the offsets of tileH, tileW, N, C for one tile. `sizes`
/// are the values for the sizes of tileH, tileW, N, C for one tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  // A static alpha extent of 1 marks an untransformed dimension.
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  // Map a tile index to an input coordinate: offset = tileOffset * m on
  // transformed dimensions; identity on untransformed ones.
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  // A run of `t` tiles covers t * m + (r - 1) input elements (overlapping
  // windows).
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // Input slice in (N, H, W, C) space.
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  // Matching output slice in (alphaH, alphaW, tileH, tileW, N, C) space.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the tiled op's result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
3437
3438//===----------------------------------------------------------------------===//
3439// WinogradOutputTransformOp
3440//===----------------------------------------------------------------------===//
3441
/// Verifies that the output shape of winograd_output_transform matches the
/// shape implied by the value operand (alphaH, alphaW, tileH, tileW, N, F)
/// and the F(m, r) configuration: each transformed dimension of size
/// m + r - 1 expands to m * tileCount output elements; a size-1 (skipped)
/// dimension stays 1.
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // A static size of 1 marks a dimension the transform skipped.
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  // H: dynamic inputs yield a dynamic expectation; otherwise the alpha extent
  // must be m + r - 1 (or 1 when untransformed).
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  // W: same rule along the width dimension.
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  // N and F carry over unchanged.
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3482
3483SmallVector<Range>
3484WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3485 Location loc = getLoc();
3486 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3487 IntegerAttr oneAttr = builder.getIndexAttr(1);
3488 Value value = getValue();
3489 int64_t valueRank = getValueOperandRank();
3490 SmallVector<Range> loopBounds(valueRank);
3491 for (unsigned dim = 0; dim < valueRank; ++dim) {
3492 loopBounds[dim].offset = zeroAttr;
3493 // alphaH, alphaW, tileH, tileW, N, F
3494 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3495 loopBounds[dim].stride = oneAttr;
3496 }
3497 return loopBounds;
3498}
3499
3500SmallVector<utils::IteratorType>
3501WinogradOutputTransformOp::getLoopIteratorTypes() {
3502 int64_t valueRank = getValueOperandRank();
3503 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3504 utils::IteratorType::parallel);
3505 return iteratorTypes;
3506}
3507
3508LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3509 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3510 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3511 SmallVector<OpFoldResult> &resultSizes) {
3512 WinogradConv2DFmr fmr = getFmr();
3513 int64_t m, r;
3514 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3515
3516 Location loc = getLoc();
3517 MLIRContext *context = builder.getContext();
3518 auto identityAffineMap =
3519 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3520 auto affineMap =
3521 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3522
3523 ShapedType valueType = getValueOperandType();
3524 ArrayRef<int64_t> valueShape = valueType.getShape();
3525 int64_t valueH = valueShape[0];
3526 int64_t valueW = valueShape[1];
3527 Value mappedOffsetH = affine::makeComposedAffineApply(
3528 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3529 offsets[getValueTileHDim()]);
3530 Value mappedOffsetW = affine::makeComposedAffineApply(
3531 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3532 offsets[getValueTileWDim()]);
3533 Value mappedSizeH = affine::makeComposedAffineApply(
3534 builder, loc, affineMap, sizes[getValueTileHDim()]);
3535 Value mappedSizeW = affine::makeComposedAffineApply(
3536 builder, loc, affineMap, sizes[getValueTileWDim()]);
3537
3538 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3539 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3540 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3541 OpFoldResult sizeH =
3542 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3543 OpFoldResult sizeW =
3544 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3545
3546 resultOffsets.append(
3547 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3548 resultSizes.append(
3549 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3550 return success();
3551}
3552
/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
/// F). The output of winograd_output_transform is (N, H, W, F). Users can
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // The value slice keeps the full alphaH x alphaW leading extent and tiles
  // the trailing (tileH, tileW, N, F) dimensions.
  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = tensor::ExtractSliceOp::create(
      builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  // Derive the matching output-tile offsets/sizes in (N, H, W, F) space.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the tiled op's result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
3608
3609//===----------------------------------------------------------------------===//
3610// LinalgDialect
3611// TODO: Merge with the LinalgDialect block at the bottom
3612//===----------------------------------------------------------------------===//
3613
3614// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3615static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3616 auto explicitRange = subMap.getResults();
3617 auto defaultRange = fullMap.getResults();
3618 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3619 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3620 llvm::set_union(explicitSet, defaultSet);
3621 return explicitSet == defaultSet;
3622}
3623
3624/// Check if the user defined map is valid broadcast map. Here broadcast
3625/// indexing maps are defined in context of corresponding default indexing maps
3626/// for the given Op. This way the check becomes very simple i.e just check the
3627/// number of result dims.
3628/// Returns true if the explictMap is broadcasted with respect to the
3629/// defaultMap.
3630static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3631 return explictMap.getNumResults() < defaultMap.getNumResults();
3632}
3633
3634/// Verifies the broadcast and transpose semantic sepecified by the explicit
3635/// indexing map for the MatmulOp \p op for each operand specified by \p
3636/// opIndex.
3637static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3638 unsigned opIndex) {
3639 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3640 SmallVector<AffineMap, 3> defaultIndexingMaps =
3641 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3642
3643 auto opIndexingMap = opIndexingMaps[opIndex];
3644 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3645 // Check general validity of indexing map results.
3646 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3647 return matmulOp->emitOpError()
3648 << "Unexpected dim expression in map result.";
3649
3650 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3651 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3652 return matmulOp->emitOpError()
3653 << "Invalid broadcast requested, should be (d2).";
3654 }
3655 return success();
3656 }
3657 return success();
3658}
3659
3660// Check general validity of input indexing map of
3661// BatchMatmulOp/BatchReduceMatmulOp.
3662template <typename OpTy>
3663static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3664 AffineMap opIndexingMap,
3665 AffineMap defaultIndexingMap, bool isLHS) {
3666 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3667 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3668 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3669 // Check the result dims are valid.
3670 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3671 return batchVariantMatmulOp->emitOpError()
3672 << "Unexpected result dim expression (outside the set of default "
3673 "result dims).";
3674
3675 // Check for valid number of result dims of input maps.
3676 if (opIndexingMap.getNumResults() > 3)
3677 return batchVariantMatmulOp->emitOpError()
3678 << "no. of result dim expressions exceeds 3.";
3679
3680 auto hasValidBatchDim = [](AffineMap map) {
3681 AffineExpr batchDim = map.getResult(0);
3682 return batchDim.isFunctionOfDim(0);
3683 };
3684
3685 // Check if the requested broadcast is valid.
3686 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3687 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3688 return batchVariantMatmulOp->emitOpError()
3689 << "Invalid broadcast requested.";
3690 } else if (!hasValidBatchDim(opIndexingMap)) {
3691 return batchVariantMatmulOp->emitOpError()
3692 << "Invalid batch dimension expression.";
3693 }
3694 return success();
3695}
3696
3697/// This function checks if the given AffineMap for the output of a
3698/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3699/// dimensions and if the output map result dimensions are valid.
3700template <typename OpTy>
3701static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3702 AffineMap opIndexingMap) {
3703 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3704 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3705 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3706 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3707 opIndexingMap.getNumResults() != 3) {
3708
3709 return batchVariantMatmulOp->emitOpError()
3710 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3711 << ").";
3712 }
3713 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3714 opIndexingMap.getNumResults() != 2) {
3715 return batchVariantMatmulOp->emitOpError()
3716 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3717 << ").";
3718 }
3719
3720 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3721 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3722 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3723 outputMap.getResult(1).isFunctionOfDim(1) &&
3724 outputMap.getResult(2).isFunctionOfDim(2)
3725 : outputMap.getResult(0).isFunctionOfDim(1) &&
3726 outputMap.getResult(1).isFunctionOfDim(2);
3727 };
3728
3729 if (!areValidOutputResultDim(opIndexingMap)) {
3730 return batchVariantMatmulOp->emitOpError()
3731 << "Invalid output map result dimension.";
3732 }
3733
3734 return success();
3735}
3736
3737/// Verifies the broadcast and transpose semantic specified by the explicit
3738/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3739/// specified by opIndex.
3740template <typename OpTy>
3741static LogicalResult
3743 unsigned opIndex) {
3744 SmallVector<AffineMap, 3> opIndexingMaps =
3745 batchVariantMatmulOp.getIndexingMapsArray();
3746 SmallVector<AffineMap, 3> defaultIndexingMaps =
3747 batchVariantMatmulOp.getDefaultIndexingMaps(
3748 batchVariantMatmulOp->getContext());
3749
3750 if (opIndexingMaps.size() != 3)
3751 return batchVariantMatmulOp->emitOpError()
3752 << "Indexing_map attribute must have 3 affine maps.";
3753
3754 auto opIndexingMap = opIndexingMaps[opIndex];
3755 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3756
3757 if (opIndex == 2 &&
3758 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3759 return failure();
3760
3761 if (opIndex != 2 &&
3762 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3763 defaultIndexingMap, opIndex == 0)))
3764 return failure();
3765
3766 return success();
3767}
3768
3769namespace mlir {
3770namespace linalg {
3771
3772std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3773 if (m == 2 && r == 3)
3774 return WinogradConv2DFmr::F_2_3;
3775 if (m == 4 && r == 3)
3776 return WinogradConv2DFmr::F_4_3;
3777 if (m == 2 && r == 5)
3778 return WinogradConv2DFmr::F_2_5;
3779 return std::nullopt;
3780}
3781
3782std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3783 switch (fmr) {
3784 case WinogradConv2DFmr::F_2_3:
3785 return {2, 3};
3786 case WinogradConv2DFmr::F_4_3:
3787 return {4, 3};
3788 case WinogradConv2DFmr::F_2_5:
3789 return {2, 5};
3790 }
3791 llvm_unreachable("Unkown WinogradConv2DFmr");
3792}
3793
3794//===----------------------------------------------------------------------===//
3795// MatMulOp
3796//===----------------------------------------------------------------------===//
3797
3798static FailureOr<SmallVector<SmallVector<int64_t>>>
3801 for (auto map : maps) {
3802 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3803 if (!attr)
3804 return failure();
3806 for (auto result : attr.getAffineMap().getResults()) {
3807 auto dim = dyn_cast<AffineDimExpr>(result);
3808 if (!dim)
3809 return failure();
3810 pos.push_back(dim.getPosition());
3811 }
3812 positions.push_back(pos);
3813 }
3814 return positions;
3815}
3816
3817/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3818SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3819 AffineExpr d0, d1, d2;
3820 SmallVector<AffineMap> indexingMaps;
3821 bindDims(context, d0, d1, d2);
3822 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3823 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3824 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3825 return indexingMaps;
3826}
3827
3828bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3829 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3830 if (!maps)
3831 return false;
3832 if (maps.size() != 3)
3833 return false;
3834 auto positions = getAffineResultPositions(maps);
3835 if (failed(positions))
3836 return false;
3837 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3838 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3839 (*positions)[2] == SmallVector<int64_t>{0, 1};
3840}
3841
3842SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3843 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3844 utils::IteratorType::parallel,
3845 utils::IteratorType::reduction};
3846}
3847
// The region has three block arguments: the two inputs and the output
// accumulator (see regionBuilder below).
unsigned MatmulOp::getNumRegionArgs() { return 3; }

// Derives the external library-call name from the operation.
std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

// The indexing maps can be overridden via the "indexing_maps" attribute.
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3855
3856/// Check if the op has broadcast and/or transpose semantic. Returns true if
3857/// the user defined indexing maps are not equal to default map.
3858bool MatmulOp::hasUserDefinedMaps() {
3859 SmallVector<AffineMap, 3> defaultMaps =
3860 getDefaultIndexingMaps(this->getContext());
3861 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3862 return defaultMaps != explicitMaps;
3863}
3864
/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'. Builds the body
///   out = out + cast(arg0) * cast(arg1)
/// where the cast kind defaults to cast_signed and may be overridden by the
/// "cast" attribute.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  // Report a diagnostic when possible; otherwise fall back to the assert.
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  // Pick the cast kind from the "cast" attribute, defaulting to cast_signed.
  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  // Cast both inputs to the accumulator (argument 2) element type.
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value1 || !value2 || !value3)
    return;
  // Accumulate into the output argument and yield.
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
3903
3904/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3905/// broadcast map must include K dimension.
3906/// TODO: Strict inclusion of K dimension in the broadcast map is not
3907/// necessary for both input matrices simultaneously. We can relax this
3908/// condition to have K dimension for one input matrix map and infer the K
3909/// dimension for other input matrix map from the one already having K
3910/// dimension.
3911bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3912 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3913 AffineExpr expr = bcastMap.getResult(0);
3914 // Invalid map if the common dimension of matmul not found.
3915 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3916}
3917
3918static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3919 if (parser.parseOptionalKeyword("indexing_maps"))
3920 return ArrayAttr{
3921 nullptr}; // Success in case indexing_maps was not provided.
3922
3923 ArrayAttr arrayAttr;
3924 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3925 return failure();
3926
3927 if (llvm::any_of(arrayAttr,
3928 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3929 return parser.emitError(parser.getCurrentLocation())
3930 << "element of indexing_maps array is not an affine_map";
3931
3932 return arrayAttr;
3933}
3934
/// Parses a MatmulOp. An explicit `indexing_maps` clause, if present, is
/// attached as-is; otherwise the canonical matmul maps are materialized so
/// the attribute is always set.
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  // Null attribute means the clause was absent: fall back to the defaults.
  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
3951
/// Prints a MatmulOp. The `indexing_maps` clause is only printed when the
/// attached maps differ from the canonical defaults.
void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  // indexing_maps is handled above; the other attributes are implementation
  // details that the generic printer must not emit.
  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}
3964
3965/// Verify the user defined indexing maps.
3966LogicalResult MatmulOp::verify() {
3967 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3968 if (!hasUserDefinedMaps())
3969 return success();
3970
3971 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3972 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3973 return failure();
3974 }
3975 return success();
3976}
3977
// Folds producer memref.cast ops into this op's operands where possible.
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // With pure tensor semantics there are no memory effects to report.
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

// Defers to the generic structured-op speculatability rules.
Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
3993
3994SmallVector<AffineMap>
3995MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3996 AffineExpr d0, d1, d2;
3997 MLIRContext *context = builder.getContext();
3998 bindDims(context, d0, d1, d2);
3999 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
4000 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
4001 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4002 return {mapLHS, mapRHS, mapOut};
4003}
4004
4006 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4007 if (!maps)
4008 return false;
4009 if (maps.size() != 3)
4010 return false;
4011 auto positions = getAffineResultPositions(maps);
4012 if (failed(positions))
4013 return false;
4014 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4015 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4016 (*positions)[2] == SmallVector<int64_t>{0, 1};
4017}
4018
4021 ValueRange inputs, ValueRange outputs,
4022 ArrayRef<NamedAttribute> attributes) {
4023 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4024 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4025}
4026
4029 ValueRange inputs, ValueRange outputs,
4030 ArrayRef<NamedAttribute> attributes) {
4031 OperationState state(location, getOperationName());
4032 build(builder, state, inputs, outputs, attributes);
4033 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4034 assert(res && "builder didn't return the right type");
4035 return res;
4036}
4037
4040 TypeRange resultTensorTypes,
4041 ValueRange inputs, ValueRange outputs,
4042 ArrayRef<NamedAttribute> attributes) {
4043 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4044 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4045}
4046
4049 TypeRange resultTensorTypes, ValueRange inputs,
4050 ValueRange outputs,
4051 ArrayRef<NamedAttribute> attributes) {
4052 OperationState state(location, getOperationName());
4053 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4054 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4055 assert(res && "builder didn't return the right type");
4056 return res;
4057}
4058
4061 TypeRange resultTensorTypes,
4062 ValueRange inputs, ValueRange outputs,
4063 Attribute cast,
4064 ArrayRef<NamedAttribute> attributes) {
4065 result.addAttribute("cast", cast);
4066 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4067 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4068}
4069
4072 TypeRange resultTensorTypes, ValueRange inputs,
4073 ValueRange outputs, Attribute cast,
4074 ArrayRef<NamedAttribute> attributes) {
4075 OperationState state(location, getOperationName());
4076 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4077 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4078 assert(res && "builder didn't return the right type");
4079 return res;
4080}
4081
4083 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4085 op->getAttr("indexing_maps"));
4086}
4087
4089MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4090 AffineExpr d0, d1, d2;
4091 MLIRContext *context = builder.getContext();
4092 bindDims(context, d0, d1, d2);
4093 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4094 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4095 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4096 return {mapLHS, mapRHS, mapOut};
4097}
4098
4100 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4101 if (!maps)
4102 return false;
4103 if (maps.size() != 3)
4104 return false;
4105 auto positions = getAffineResultPositions(maps);
4106 if (failed(positions))
4107 return false;
4108 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4109 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4110 (*positions)[2] == SmallVector<int64_t>{0, 1};
4111}
4112
4115 ValueRange inputs, ValueRange outputs,
4116 ArrayRef<NamedAttribute> attributes) {
4117 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4118 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4119}
4120
4123 ValueRange inputs, ValueRange outputs,
4124 ArrayRef<NamedAttribute> attributes) {
4125 OperationState state(location, getOperationName());
4126 build(builder, state, inputs, outputs, attributes);
4127 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4128 assert(res && "builder didn't return the right type");
4129 return res;
4130}
4131
4134 TypeRange resultTensorTypes,
4135 ValueRange inputs, ValueRange outputs,
4136 ArrayRef<NamedAttribute> attributes) {
4137 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4138 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4139}
4140
4143 TypeRange resultTensorTypes, ValueRange inputs,
4144 ValueRange outputs,
4145 ArrayRef<NamedAttribute> attributes) {
4146 OperationState state(location, getOperationName());
4147 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4148 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4149 assert(res && "builder didn't return the right type");
4150 return res;
4151}
4152
4155 TypeRange resultTensorTypes,
4156 ValueRange inputs, ValueRange outputs,
4157 Attribute cast,
4158 ArrayRef<NamedAttribute> attributes) {
4159 result.addAttribute("cast", cast);
4160 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4161 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4162}
4163
4166 TypeRange resultTensorTypes, ValueRange inputs,
4167 ValueRange outputs, Attribute cast,
4168 ArrayRef<NamedAttribute> attributes) {
4169 OperationState state(location, getOperationName());
4170 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4171 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4172 assert(res && "builder didn't return the right type");
4173 return res;
4174}
4175
4177 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4179 op->getAttr("indexing_maps"));
4180}
4181
4183BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4184 AffineExpr d0, d1, d2, d3;
4185 MLIRContext *context = builder.getContext();
4186 bindDims(context, d0, d1, d2, d3);
4187 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4188 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4189 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4190 return {mapLHS, mapRHS, mapOut};
4191}
4192
4194 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4195 if (!maps)
4196 return false;
4197 if (maps.size() != 3)
4198 return false;
4199 auto positions = getAffineResultPositions(maps);
4200 if (failed(positions))
4201 return false;
4202 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4203 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4204 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4205}
4206
4208 OpBuilder &builder, OperationState &result, ValueRange inputs,
4209 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4210 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4211 BatchMatmulOp::getRegionBuilder(),
4212 getDefaultIndexingMaps(builder));
4213}
4214
4217 ValueRange inputs, ValueRange outputs,
4218 ArrayRef<NamedAttribute> attributes) {
4219 OperationState state(location, getOperationName());
4220 build(builder, state, inputs, outputs, attributes);
4221 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4222 assert(res && "builder didn't return the right type");
4223 return res;
4224}
4225
4227 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4228 ValueRange inputs, ValueRange outputs,
4229 ArrayRef<NamedAttribute> attributes) {
4230 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4231 BatchMatmulOp::getRegionBuilder(),
4232 getDefaultIndexingMaps(builder));
4233}
4234
4237 TypeRange resultTensorTypes, ValueRange inputs,
4238 ValueRange outputs,
4239 ArrayRef<NamedAttribute> attributes) {
4240 OperationState state(location, getOperationName());
4241 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4242 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4243 assert(res && "builder didn't return the right type");
4244 return res;
4245}
4246
4248 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4249 ValueRange inputs, ValueRange outputs, Attribute cast,
4250 ArrayRef<NamedAttribute> attributes) {
4251 result.addAttribute("cast", cast);
4252 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4253 BatchMatmulOp::getRegionBuilder(),
4254 getDefaultIndexingMaps(builder));
4255}
4256
4259 TypeRange resultTensorTypes, ValueRange inputs,
4260 ValueRange outputs, Attribute cast,
4261 ArrayRef<NamedAttribute> attributes) {
4262 OperationState state(location, getOperationName());
4263 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4264 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4265 assert(res && "builder didn't return the right type");
4266 return res;
4267}
4268
4270 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4272 op->getAttr("indexing_maps"));
4273}
4274
4276BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4277 AffineExpr d0, d1, d2, d3;
4278 MLIRContext *context = builder.getContext();
4279 bindDims(context, d0, d1, d2, d3);
4280 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4281 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4282 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4283 return {mapLHS, mapRHS, mapOut};
4284}
4285
4287 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4288 if (!maps)
4289 return false;
4290 if (maps.size() != 3)
4291 return false;
4292 auto positions = getAffineResultPositions(maps);
4293 if (failed(positions))
4294 return false;
4295 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4296 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4297 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4298}
4299
4301 OpBuilder &builder, OperationState &result, ValueRange inputs,
4302 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4303 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4304 BatchMatmulOp::getRegionBuilder(),
4305 getDefaultIndexingMaps(builder));
4306}
4307
4310 ValueRange inputs, ValueRange outputs,
4311 ArrayRef<NamedAttribute> attributes) {
4312 OperationState state(location, getOperationName());
4313 build(builder, state, inputs, outputs, attributes);
4314 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4315 assert(res && "builder didn't return the right type");
4316 return res;
4317}
4318
4320 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4321 ValueRange inputs, ValueRange outputs,
4322 ArrayRef<NamedAttribute> attributes) {
4323 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4324 BatchMatmulOp::getRegionBuilder(),
4325 getDefaultIndexingMaps(builder));
4326}
4327
4330 TypeRange resultTensorTypes, ValueRange inputs,
4331 ValueRange outputs,
4332 ArrayRef<NamedAttribute> attributes) {
4333 OperationState state(location, getOperationName());
4334 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4335 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4336 assert(res && "builder didn't return the right type");
4337 return res;
4338}
4339
4341 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4342 ValueRange inputs, ValueRange outputs, Attribute cast,
4343 ArrayRef<NamedAttribute> attributes) {
4344 result.addAttribute("cast", cast);
4345 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4346 BatchMatmulOp::getRegionBuilder(),
4347 getDefaultIndexingMaps(builder));
4348}
4349
4352 TypeRange resultTensorTypes, ValueRange inputs,
4353 ValueRange outputs, Attribute cast,
4354 ArrayRef<NamedAttribute> attributes) {
4355 OperationState state(location, getOperationName());
4356 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4357 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4358 assert(res && "builder didn't return the right type");
4359 return res;
4360}
4361
4363 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4365 op->getAttr("indexing_maps"));
4366}
4367
4368//===----------------------------------------------------------------------===//
4369// ContractOp
4370//===----------------------------------------------------------------------===//
4371
4372SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4373 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4374 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4375 // domains are all the same, and each implements a projected permutation.
4376 // Each iteration space dim must occur for at least one operand and either
4377 // takes part in a contraction/reduction or else has parallel iteration type.
4378 // We have that a dim is a contraction/reduction dim if and only if the dim
4379 // occurs for the output operand. We use this fact for fast inference:
4380 // NB: In case we allow dims to occur solely for one input, the above still
4381 // holds: per the einsum semantics, these are reduction dims as well.
4382 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4383 for (auto result : outAffineMap.getResults()) {
4384 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4385 assert(dimExpr && "affine_map is a projected permutation");
4386 dimsInOutput[dimExpr.getPosition()] = true;
4387 }
4388
4390 for (auto dimOccursInOutput : dimsInOutput)
4391 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4392 : utils::IteratorType::reduction);
4393
4394 return iteratorTypes;
4395}
4396
/// The contract region takes the two inputs plus the accumulator output.
unsigned ContractOp::getNumRegionArgs() { return 3; }
4398
4399/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4400void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4401 ArrayRef<NamedAttribute> attrs,
4402 function_ref<InFlightDiagnostic()> emitError) {
4403 if (emitError && block.getNumArguments() != 3) {
4404 emitError() << "ContractOp regionBuilder expects 3 args, got "
4405 << block.getNumArguments();
4406 return;
4407 }
4408 assert(block.getNumArguments() == 3 &&
4409 "ContractOp regionBuilder expects 3 args");
4410 RegionBuilderHelper helper(b, block);
4411
4412 TypeFn castSignedness = TypeFn::cast_signed;
4413 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4414 return attr.getName() == "cast";
4415 });
4416 if (castIter != attrs.end()) {
4417 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4418 castSignedness = attr.getValue();
4419 }
4420
4421 // TODO: Support fields with operators besides mult & add.
4422 Type outType = block.getArgument(2).getType();
4423 Value lhsAtOutType =
4424 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4425 Value rhsAtOutType =
4426 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4427 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4428 rhsAtOutType, emitError);
4429 if (!productAtOutType)
4430 return;
4431 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4432 productAtOutType, emitError);
4433 if (!result)
4434 return;
4435 helper.yieldOutputs({result});
4436}
4437
4438ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4439 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4440 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4441 return parser.emitError(parser.getCurrentLocation(),
4442 "expected 'indexing_maps' attribute");
4443 result.addAttribute("indexing_maps", *indexingMapsAttr);
4444
4445 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4446 regionBuilder);
4447}
4448
4449void ContractOp::print(OpAsmPrinter &p) {
4450 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4452 p, getOperation(), getInputs(), getOutputs(),
4453 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4454}
4455
/// Verifies a ContractOp: every indexing map must be a projected permutation
/// over one shared iteration space, every iteration dim must access some
/// operand, and at least one contracting (both-inputs, not-output) dim must
/// exist.
LogicalResult ContractOp::verify() {
  // Rank of the shared iteration space; -1 until the first map fixes it.
  int iterationSpaceDims = -1;
  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
  // access an input operand (so occurrence count can be at most 2) and
  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // A helper so that for each operand's affine_map and type we check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      // Scalar (non-shaped) operands must use a zero-result map.
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    if (iterationSpaceDims == -1) {
      // First map seen fixes the iteration-space rank and sizes the counters.
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  // Operand order is (lhs, rhs, out); the bool flags mirror that order.
  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // We have a contracting dim if and only if ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: We disallow a dim which occurs for only one input operand and not
    // for the output. In terms of einsum semantics such dims have a
    // sensible meaning - namely an additional reduction per each such dim.
    // By contrast, the ContractionOpInterface does not know about this
    // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
    // while vector.contract's verifier accepts dims of this kind many of
    // its lowerings give up on encountering these dims.
    // TODO: Remove following once we have comprehensive support for input-only
    // reduction dims, at both the linalg- and vector-dialect levels.
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}
4545
/// Folds `memref.cast` producers into this op's operands where possible.
LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4549
4550void ContractOp::getEffects(
4551 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4552 &effects) {
4553 if (hasPureTensorSemantics())
4554 return;
4555 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4556}
4557
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability ContractOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4561
4562//===----------------------------------------------------------------------===//
4563// Implementation of BatchMatmulOp
4564//===----------------------------------------------------------------------===//
4565SmallVector<AffineMap>
4566BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4567 AffineExpr d0, d1, d2, d3;
4568 SmallVector<AffineMap> indexingMaps;
4569 bindDims(context, d0, d1, d2, d3);
4570 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4571 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4572 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4573 return indexingMaps;
4574}
4575
4576bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4577 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4578 if (!maps)
4579 return false;
4580 if (maps.size() != 3)
4581 return false;
4582 auto positions = getAffineResultPositions(maps);
4583 if (failed(positions))
4584 return false;
4585 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4586 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4587 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4588}
4589
4590SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4591 return SmallVector<utils::IteratorType>{
4592 utils::IteratorType::parallel, utils::IteratorType::parallel,
4593 utils::IteratorType::parallel, utils::IteratorType::reduction};
4594}
4595
/// The region takes the two inputs plus the accumulator output.
unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4597
/// Returns the library call name generated for this operation.
std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
4601
4602/// Check if the op has broadcast and/or transpose semantic. Returns true if
4603/// the user defined indexing maps are not equal to default map.
4604bool BatchMatmulOp::hasUserDefinedMaps() {
4605 SmallVector<AffineMap, 3> defaultMaps =
4606 getDefaultIndexingMaps(this->getContext());
4607 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4608 return defaultMaps != explicitMaps;
4609}
4610
4611/// Returns true if the given bcastMap map is a valid broadcast map. A valid
4612/// broadcast map must include K dimension.
4613/// TODO: Strict inclusion of K dimension in the broadcast map is not
4614/// necessary for both input matrices simultaneously. We can relax this
4615/// condition to have K dimension for one input matrix map and infer the K
4616/// dimension for other input matrix map from the one already having K
4617/// dimension.
4618bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4619 assert(bcastMap.getNumResults() < 3 &&
4620 "Expected less than 3 result dim expr.");
4621 bool isValid = false;
4622 enum Indices { batchPos, mPos, nPos, kPos };
4623 if (bcastMap.getNumResults() == 1) {
4624 AffineExpr expr = bcastMap.getResult(0);
4625 isValid = expr.isFunctionOfDim(kPos);
4626 } else if (bcastMap.getNumResults() == 2) {
4627 AffineExpr expr0 = bcastMap.getResult(0);
4628 AffineExpr expr1 = bcastMap.getResult(1);
4629 isValid =
4630 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4631 expr0.isFunctionOfDim(mPos)) &&
4632 expr1.isFunctionOfDim(kPos))
4633 : ((expr0.isFunctionOfDim(batchPos) &&
4634 expr1.isFunctionOfDim(kPos)) ||
4635 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4636 }
4637 return isValid;
4638}
4639
4640void BatchMatmulOp::regionBuilder(
4641 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4642 function_ref<InFlightDiagnostic()> emitError) {
4643 if (emitError && block.getNumArguments() != 3) {
4644 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4645 << block.getNumArguments();
4646 return;
4647 }
4648 assert(block.getNumArguments() == 3 &&
4649 "BatchMatmulOp regionBuilder expects 3 args");
4650 RegionBuilderHelper helper(b, block);
4651 SmallVector<Value> yields;
4652
4653 TypeFn castVal = TypeFn::cast_signed;
4654 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4655 return attr.getName() == "cast";
4656 });
4657 if (castIter != attrs.end()) {
4658 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4659 castVal = attr.getValue();
4660 }
4661
4662 auto toType = block.getArgument(2).getType();
4663 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4664 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4665 Value mulVal =
4666 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
4667 if (!castValA || !castValB || !mulVal)
4668 return;
4669 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4670 mulVal, emitError);
4671 if (!addVal)
4672 return;
4673 yields.push_back(addVal);
4674 helper.yieldOutputs(yields);
4675}
4676
4677ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4678 SmallVector<Attribute, 3> indexingMapsAttr;
4679 Attribute mapAttr;
4680 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4681 if (parser.parseEqual())
4682 return failure();
4683
4684 if (parser.parseLSquare())
4685 return failure();
4686
4687 do {
4688 if (parser.parseAttribute(mapAttr))
4689 return failure();
4690 if (!isa<AffineMapAttr>(mapAttr)) {
4691 return parser.emitError(parser.getCurrentLocation(),
4692 "expected affine map attribute");
4693 }
4694 indexingMapsAttr.push_back(mapAttr);
4695
4696 if (parser.parseOptionalComma())
4697 break;
4698 } while (true);
4699
4700 if (parser.parseRSquare())
4701 return failure();
4702 }
4703 // Initialize indexingMaps, if not supplied explicitly.
4704 if (indexingMapsAttr.empty()) {
4705 indexingMapsAttr = llvm::map_to_vector(
4706 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4707 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4708 }
4709 result.addAttribute("indexing_maps",
4710 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4711
4712 return ::parseNamedStructuredOp(parser, result,
4713 BatchMatmulOp::getNumRegionArgs(),
4714 BatchMatmulOp::getRegionBuilder());
4715}
4716
4717void BatchMatmulOp::print(OpAsmPrinter &p) {
4718 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4719 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4720 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4721 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4722 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4723
4724 std::array<StringRef, 3> elidedAttrs = {
4725 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4726 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4727 elidedAttrs);
4728}
4729
4730/// Verify the user defined indexing maps.
4731LogicalResult BatchMatmulOp::verify() {
4732 // Verification of pure batch_matmul is handled by
4733 // verifyStructuredOpInterface().
4734 if (!hasUserDefinedMaps())
4735 return success();
4736
4737 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4739 return failure();
4740 }
4741 return success();
4742}
4743
/// Folds `memref.cast` producers into this op's operands where possible.
LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4748
4749void BatchMatmulOp::getEffects(
4750 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4751 &effects) {
4752 if (hasPureTensorSemantics())
4753 return;
4754 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4755}
4756
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4760
4761//===----------------------------------------------------------------------===//
4762// ElementwiseOp
4763//===----------------------------------------------------------------------===//
4764//
namespace {
/// Pairs an elementwise op's arity group with the concrete operation
/// ("kind") within that group; the active union member is determined by
/// `arityGroup`.
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ..}
  ElementwiseArityGroup arityGroup;

  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

/// Returns the numeric value of the arity group, which equals the number of
/// inputs the elementwise op takes (callers add 1 for the output).
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}
} // namespace
4782
4783static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4784 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4785 constexpr int lastBinary =
4786 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4787 constexpr int lastTernary =
4788 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4789
4790 int val = static_cast<int>(kind);
4791 ArityGroupAndKind result;
4792
4793 if (val < lastUnary) {
4794 result.arityGroup = ElementwiseArityGroup::Unary;
4795 result.kind.unaryFn = static_cast<UnaryFn>(val);
4796 return result;
4797 }
4798 if (val < lastBinary) {
4799 result.arityGroup = ElementwiseArityGroup::Binary;
4800 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4801 return result;
4802 }
4803 if (val >= lastTernary) {
4804 llvm_unreachable("unhandled ElementwiseFn");
4805 }
4806 result.arityGroup = ElementwiseArityGroup::Ternary;
4807 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4808 return result;
4809}
4810
4811SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4812 auto rank = getResultRank();
4813 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4814}
4815
4817ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4818 MLIRContext *context) {
4819 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4820 return SmallVector<AffineMap>(numMaps, map);
4821}
4822
/// Parses an ElementwiseOp: a mandatory `kind = #linalg.elemwise_kind<...>`
/// attribute, an optional `indexing_maps = [...]` clause, then the generic
/// named-structured-op tail. The number of region arguments is inferred from
/// `kind`; missing indexing maps default to identity maps of the output rank.
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
  Attribute attr;
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  if (parser.parseKeyword("kind") || parser.parseEqual())
    return failure();

  if (succeeded(parser.parseAttribute(attr))) {
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  }
  result.addAttribute(
      "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));

  // Parse optional `indexing_maps`
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr))
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      indexingMapsAttr.push_back(mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }
  // At this stage of parsing the only way to infer number of region
  // args is through op kind, as input output tensors are not parsed yet.
  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
  int numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    // We need to infer the numDims of the indexing maps from the output
    // type which is already parsed by now.
    auto resultType = result.operands[result.operands.size() - 1].getType();
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }

  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return success();
}
4895
4896void ElementwiseOp::print(OpAsmPrinter &p) {
4897 p << " kind=";
4898 p.printAttribute(getKindAttr());
4899 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4900 "indexing_maps"};
4901 unsigned arity =
4902 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4903 unsigned numDims = getResultRank();
4904
4905 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4906 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4907 getContext()),
4908 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4909
4910 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4911 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4912
4913 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4914 elidedAttrs);
4915}
4916
4917/// Implements the block region builder for the ElementwiseOp. This is called by
4918/// 'fillStructuredOpRegion'.
4919void ElementwiseOp::regionBuilder(
4920 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4921 function_ref<InFlightDiagnostic()> emitError) {
4922 ElementwiseKind elemwiseKind;
4923 for (auto attr : attrs) {
4924 if (attr.getName() == b.getStringAttr("kind")) {
4925 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4926 assert(kindAttr && "op kind attribute incorrectly set");
4927 elemwiseKind = kindAttr.getValue();
4928 break;
4929 }
4930 }
4931
4932 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4933 auto arityGroup = groupAndKind.arityGroup;
4934 auto kind = groupAndKind.kind;
4935 if (emitError && block.getNumArguments() !=
4936 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4937 emitError() << "Elementwise regionBuilder expects "
4938 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4939 << block.getNumArguments();
4940 return;
4941 }
4942 assert(block.getNumArguments() ==
4943 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4944 && "Elementwise regionBuilder number of block args mismatch");
4945
4946 RegionBuilderHelper helper(b, block);
4947 SmallVector<Value> yields;
4948 Value result;
4949
4950 if (arityGroup == ElementwiseArityGroup::Unary) {
4951 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4952
4953 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4954 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4955 block.getArgument(1));
4956
4957 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4958 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4959 block.getArgument(1), block.getArgument(2));
4960
4961 } else {
4962 assert(false && "found unhandled category in elemwise");
4963 }
4964
4965 yields.push_back(result);
4966 helper.yieldOutputs(yields);
4967}
4968
/// Folds `memref.cast` producers into this op's operands where possible.
LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4973
4974void ElementwiseOp::getEffects(
4975 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4976 &effects) {
4977 if (hasPureTensorSemantics())
4978 return;
4979 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4980}
4981
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability ElementwiseOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4985
4986//===----------------------------------------------------------------------===//
4987// PackOp/UnPackOp Common
4988//===----------------------------------------------------------------------===//
4989
4990template <typename OpTy, typename>
4991SmallVector<int64_t>
4993 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4994 ? packOrUnPack.getDestType()
4995 : packOrUnPack.getSourceType();
4996 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4997 ? packOrUnPack.getSourceType()
4998 : packOrUnPack.getDestType();
5000 packedType.getShape().take_front(unpackedType.getRank()));
5001 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5003 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
5004 }
5005 return result;
5006}
5011
5012// Given the (potentially) updated packed type, `newPackedTy`, generates an
5013// updated mixed-tile-sizes attribute. A tile size is updated only
5014// when:
5015//  * a dim from newPackedTy is static, and
5016//  * the corresponding size from mixedTiles is still dynamic.
5017// Otherwise, the original tile size is preserved.
5018// Note - packed-type-dim and mixed-tile-size should always match!
5021                     SmallVector<OpFoldResult> mixedTiles) {
5022  SmallVector<OpFoldResult> newMixedTileSizes;
  // Walk the trailing (tile) dims of the packed type together with the
  // current mixed tile sizes; they correspond 1:1.
5023  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5024                               .getShape()
5025                               .take_back(mixedTiles.size()),
5026                           mixedTiles)) {
5027    int64_t dimSize = std::get<0>(it);
5028    if (dimSize == ShapedType::kDynamic) {
      // Dim still dynamic: keep whatever tile size we had.
5029      newMixedTileSizes.push_back(std::get<1>(it));
5030      continue;
5031    }
5032
5033    // If the current result dim is static, update the dynamic mixed-size
5034    // (provided the original value is dynamic).
5035    OpFoldResult tile = std::get<1>(it);
5036    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
5037      // Already a constant
5038      newMixedTileSizes.push_back(tile);
5039    } else {
      // Dynamic tile + static dim: replace the Value with an index attr.
5040      assert(getConstantIntValue(tile).value() == dimSize &&
5041             "tile size and dim size don't match!");
5042      newMixedTileSizes.push_back(
5043          (rewriter.getIntegerAttr(rewriter.getIndexType(), dimSize)));
5044    }
5045  }
5046
5047  return newMixedTileSizes;
5048}
5049
/// Shared reifyResultShapes for pack/unpack: the result always has exactly
/// the shape of the `dest` operand, so each result dim is a (folded)
/// dim-query of `dest`.
5050template <typename OpTy>
5051static LogicalResult
5053                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5054  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5055                "applies to only pack or unpack operations");
5056  int64_t destRank = op.getDestRank();
5057  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5058  for (auto dim : llvm::seq<int64_t>(0, destRank))
5059    reifiedReturnShapes[0][dim] =
5060        createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
5061  return success();
5062}
5063
/// Builds the mapping from each tiled dimension position (`inner_dims_pos`
/// entry) to its mixed (static-or-dynamic) tile size.
5064template <typename OpTy>
5066  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5067                "applies to only pack or unpack operations");
5068  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5069  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5070  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5071  assert(tiles.size() == dimsToTile.size() &&
5072         "tiles must match indices of dimension to block");
5073  // bind the dimension `i` with the tile factor.
5074  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5075    dimAndTileMapping[dimsToTile[i]] = tiles[i];
5076  return dimAndTileMapping;
5077}
5078
/// Merges the static tile sizes (attribute) and dynamic tile sizes (SSA
/// operands) back into a single OpFoldResult list, in declaration order.
5079template <typename OpTy>
5081  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5082                "applies to only pack or unpack operations");
5083  Builder builder(op);
5084  SmallVector<OpFoldResult> mixedInnerTiles;
5085  unsigned dynamicValIndex = 0;
  // Static entries become I64 attrs; dynamic placeholders consume the next
  // dynamic tile operand.
5086  for (int64_t staticTile : op.getStaticInnerTiles()) {
5087    if (ShapedType::isStatic(staticTile))
5088      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5089    else
5090      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5091  }
5092  return mixedInnerTiles;
5093}
5094
/// Returns the static view of the mixed tile sizes (dynamic entries become
/// the kDynamic sentinel via dispatchIndexOpFoldResults).
5095template <typename OpTy>
5097  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5098                "applies to only pack or unpack operations");
5099  SmallVector<Value> dynamicTiles;
5100  SmallVector<int64_t> staticTiles;
5101  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5102  return staticTiles;
5103}
5104
5105/// Returns true if `dimsPos` is invalid. It is invalid when:
5106/// a) It contains duplicate.
5107/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
5108/// c) The number of elements in `dimsPos` is > than `rank`.
5110                                             size_t rank) {
5111  size_t dimsPosSize = dimsPos.size();
5112  if (dimsPosSize > rank)
5113    return true;
  // Deduplicate; a size change means there was a repeated position.
5114  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5115  if (dimsPosSize != uniqued.size())
5116    return true;
  // Any position outside [0, rank) is out of bounds.
5117  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5118    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5119  });
5120}
5121
5122template <typename OpTy>
5123static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5124 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5125 "applies to only pack or unpack operations");
5126 Operation *op = packOrUnPack.getOperation();
5127
5128 // Return true if we have a zero-value tile.
5129 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5130 return llvm::any_of(tiles, [](OpFoldResult tile) {
5131 return isa<Attribute>(tile) && isZeroInteger(tile);
5132 });
5133 };
5134
5135 // Verify that the source and destination are ranked types.
5136 if (!packOrUnPack.getSourceType().hasRank() ||
5137 !packOrUnPack.getDestType().hasRank())
5138 return op->emitError("expected both source and destination to have rank");
5139
5140 // Verify that the Operation does not have mixed tensor/buffer semantics.
5141 if (!packOrUnPack.hasPureBufferSemantics() &&
5142 !packOrUnPack.hasPureTensorSemantics())
5143 return op->emitError("mixing tensor and buffer semantics is not allowed");
5144 const unsigned numResults = packOrUnPack.getNumResults();
5145 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5146 return op->emitError("expected 1 result, got ") << numResults;
5147 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5148 return op->emitError("expected 0 results, got ") << numResults;
5149
5150 // Verify tiles. Do not allow zero tiles.
5151 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5152 if (hasZeros(mixedTiles))
5153 return op->emitError("invalid zero tile factor");
5154
5155 // Verify inner_dims_pos and outer_dims_perm.
5156 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5157 ? packOrUnPack.getSourceType()
5158 : packOrUnPack.getDestType();
5159 size_t unpackedRank = unpackedType.getRank();
5160 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5161 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5162 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5163 return op->emitError("invalid inner_dims_pos vector");
5164 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5165 return op->emitError("invalid outer_dims_perm vector");
5166 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5167 return op->emitError("outer_dims_perm must be a permutation or empty");
5168
5169 // Tiling factors must be less than or equal to the input rank for pack (or
5170 // output rank for unpack), and must match the number of `inner_dims_pos`.
5171 if (mixedTiles.size() > unpackedRank) {
5172 return op->emitError("tiling factors must be less than or equal to the "
5173 "input rank for pack or output rank for unpack");
5174 }
5175 if (mixedTiles.size() != innerDimsPos.size()) {
5176 return op->emitError(
5177 "tiling factors must equal the number of dimensions to tile");
5178 }
5179
5180 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5181 ? packOrUnPack.getDestType()
5182 : packOrUnPack.getSourceType();
5183 size_t packedRank = packedType.getRank();
5184 // Require output rank to match input rank + number of blocking factors.
5185 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5186 if (expectedPackedRank != packedRank) {
5187 return op->emitError(
5188 "packed rank != (unpacked rank + num tiling factors), got ")
5189 << packedRank << " != " << expectedPackedRank;
5190 }
5191
5192 // Verify result shape is greater than the minimum expected
5193 // by the pack operation, and that the output shape
5194 // represents full tiles.
5195 SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
5196 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5197 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5198 for (auto it : llvm::enumerate(llvm::zip(
5199 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5200 int64_t dimSize = std::get<0>(it.value());
5201 if (Attribute attr =
5202 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5203 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5204 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5205 if (dimSize != staticTileSize)
5206 return op->emitError(
5207 "mismatch in inner tile sizes specified and shaped of "
5208 "tiled dimension in the packed type at index ")
5209 << it.index() << ": got " << dimSize << " != " << staticTileSize;
5210 } else if (!ShapedType::isDynamic(dimSize)) {
5211 return op->emitError("mismatch in inner tile sizes specified at index ")
5212 << it.index() << ": got static shape " << dimSize
5213 << " but dynamic tile size";
5214 }
5215 }
5216 if (failed(
5217 verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
5218 auto elementType = unpackedType.getElementType();
5219 Type expectedType, actualType;
5220 if (packOrUnPack.hasPureTensorSemantics()) {
5221 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5222 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5223 } else {
5224 expectedType = MemRefType::get(expectedPackedShape, elementType);
5225 actualType = MemRefType::get(packedType.getShape(), elementType);
5226 }
5227 return op->emitError("expected ")
5228 << expectedType << " for the packed domain value, got "
5229 << actualType;
5230 }
5231 return success();
5232}
5233
5234namespace {
5235/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5236/// various permutations to the op.
5237// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5238// these. These may or may not become true foldings / canonicalizations
5239// depending on how aggressive we want to be in automatically folding
5240// transposes.
5241struct PackOrUnPackTransposeResult {
  // Permuted `inner_dims_pos` values.
5242  SmallVector<int64_t> innerDimsPos;
  // Permuted mixed (static/dynamic) inner tile sizes.
5243  SmallVector<OpFoldResult> innerTiles;
  // Permuted `outer_dims_perm` (identity materialized if it was empty).
5244  SmallVector<int64_t> outerDimsPerm;
5245};
5246} // namespace
5247
/// Applies `innerPermutation` to the op's inner dims/tiles and
/// `outerPermutation` to its outer-dims permutation, returning the permuted
/// metadata. Either permutation may be empty (no-op for that part), but not
/// both. An empty `outer_dims_perm` on the op is first materialized as the
/// identity so the outer permutation can be applied.
5248template <typename OpTy>
5249static PackOrUnPackTransposeResult
5251                                   ArrayRef<int64_t> innerPermutation,
5252                                   ArrayRef<int64_t> outerPermutation) {
5253  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5254                "applies to only pack or unpack operations");
5255  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5256         "some permutation must be non-empty");
5257  PackOrUnPackTransposeResult metadata;
5258  metadata.innerDimsPos =
5259      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5260  metadata.innerTiles =
5261      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  // The outer dims live on the unpacked side: source for pack, dest for
  // unpack.
5262  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5263                             ? packOrUnPackOp.getSourceRank()
5264                             : packOrUnPackOp.getDestRank();
5265  metadata.outerDimsPerm =
5266      packOrUnPackOp.getOuterDimsPerm().empty()
5267          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5268          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5269  if (!innerPermutation.empty()) {
5270    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5271           isPermutationVector(innerPermutation) &&
5272           "invalid inner permutation");
5273    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5274    applyPermutationToVector(metadata.innerTiles, innerPermutation);
5275  }
5276  if (!outerPermutation.empty()) {
5277    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5278           isPermutationVector(outerPermutation) &&
5279           "invalid outer permutation");
5280    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5281  }
5282  return metadata;
5283}
5284
5285//===----------------------------------------------------------------------===//
5286// PackOp
5287//===----------------------------------------------------------------------===//
5288
5289void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5290 if (!getResults().empty())
5291 setNameFn(getResult(), "pack");
5292}
5293
/// Parses the custom assembly form:
///   %src [padding_value(%pad : type)] [outer_dims_perm = [...]]
///   inner_dims_pos = [...] inner_tiles = [...] into %dest {attrs}
///   : src-type -> dest-type
5294ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
5295  OpAsmParser::UnresolvedOperand source, dest;
5298  SmallVector<Type> paddingValueType;
5299  SmallVector<int64_t> staticTiles;
5300  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
5301  Type sourceType, destType, resultType;
5302
5303  if (parser.parseOperand(source))
5304    return failure();
5305
  // Optional `padding_value(%v : type)` clause.
5306  if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
5307    if (parser.parseLParen() ||
5308        parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
5309        parser.parseColon() || parser.parseTypeList(paddingValueType) ||
5310        parser.parseRParen())
5311      return failure();
5312  }
5313
  // Optional `outer_dims_perm = [i, j, ...]` clause.
5314  if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
5315    if (parser.parseEqual())
5316      return failure();
5317
5318    SmallVector<int64_t> outerDimsPermVec;
5320      int64_t value;
5321      if (parser.parseInteger(value))
5322        return failure();
5323      outerDimsPermVec.push_back(value);
5324      return success();
5325    }))
5326      return failure();
5327    outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
5328  }
5329
  // Mandatory `inner_dims_pos = [...]` clause.
5330  if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
5331    return failure();
5332
5333  SmallVector<int64_t> innerDimsPosVec;
5335    int64_t value;
5336    if (parser.parseInteger(value))
5337      return failure();
5338    innerDimsPosVec.push_back(value);
5339    return success();
5340  }))
5341    return failure();
5342  innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
5343
  // Mandatory `inner_tiles = [...]` — a mixed static/dynamic index list.
5344  if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
5345    return failure();
5346
5347  DenseI64ArrayAttr staticTilesAttr;
5348  if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
5349    return failure();
5350  for (auto val : staticTilesAttr.asArrayRef())
5351    staticTiles.push_back(val);
5352
5353  if (parser.parseKeyword("into") || parser.parseOperand(dest))
5354    return failure();
5355
5356  if (parser.parseOptionalAttrDict(result.attributes))
5357    return failure();
5358
  // Trailing types: `: src-type -> dest-type` (the arrow is required).
5359  if (parser.parseColon() || parser.parseType(sourceType))
5360    return failure();
5361
5362  bool hasArrow = succeeded(parser.parseOptionalArrow());
5363  if (hasArrow) {
5364    if (parser.parseType(destType))
5365      return failure();
5366  }
5367
5368  bool isMemRef = llvm::isa<MemRefType>(sourceType);
5369  if (!hasArrow) {
5370    return parser.emitError(parser.getCurrentLocation(),
5371                            "pack/unpack requires '->' and destination type");
5372  }
5373
  // Only the tensor form produces a result (of the dest type).
5374  if (!isMemRef)
5375    resultType = destType;
5376
5377  if (parser.resolveOperand(source, sourceType, result.operands) ||
5378      parser.resolveOperand(dest, destType, result.operands))
5379    return failure();
5380
5381  if (!paddingValue.empty() &&
5382      parser.resolveOperands(paddingValue, paddingValueType[0],
5383                             result.operands))
5384    return failure();
5385
  // Dynamic tile sizes are always index-typed.
5386  if (!dynamicTiles.empty() &&
5387      parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
5388                             result.operands))
5389    return failure();
5390
5391  result.addAttribute("static_inner_tiles",
5392                      parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
5393  result.addAttribute("inner_dims_pos", innerDimsPos);
5394  if (outerDimsPerm)
5395    result.addAttribute("outer_dims_perm", outerDimsPerm);
5396
  // Operand groups: source, dest, optional padding value, dynamic tiles.
5397  SmallVector<int32_t> segmentSizes = {
5398      1, 1, static_cast<int32_t>(paddingValue.size()),
5399      static_cast<int32_t>(dynamicTiles.size())};
5400  result.addAttribute("operandSegmentSizes",
5401                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
5402
5403  if (!isMemRef)
5404    result.addTypes(resultType);
5405
5406  return success();
5407}
5408
5409void PackOp::print(OpAsmPrinter &p) {
5410 p << " " << getSource();
5411
5412 if (getPaddingValue()) {
5413 p << " padding_value(" << getPaddingValue() << " : "
5414 << getPaddingValue().getType() << ")";
5415 }
5416
5417 if (!getOuterDimsPerm().empty()) {
5418 p << " outer_dims_perm = [";
5419 llvm::interleaveComma(getOuterDimsPerm(), p);
5420 p << "]";
5421 }
5422
5423 p << " inner_dims_pos = [";
5424 llvm::interleaveComma(getInnerDimsPos(), p);
5425 p << "]";
5426
5427 p << " inner_tiles = ";
5428 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
5429
5430 p << " into " << getDest();
5431
5432 p.printOptionalAttrDict((*this)->getAttrs(),
5433 {"static_inner_tiles", "inner_dims_pos",
5434 "outer_dims_perm", "operandSegmentSizes"});
5435
5436 p << " : " << getSource().getType();
5437 p << " -> " << getDest().getType();
5438}
5439
5440void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5441 Value dest, ArrayRef<int64_t> innerDimsPos,
5442 ArrayRef<OpFoldResult> innerTiles,
5443 std::optional<Value> paddingValue,
5444 ArrayRef<int64_t> outerDimsPerm) {
5445 assert(innerDimsPos.size() == innerTiles.size() &&
5446 "number of tile sizes specified must match the specified number of "
5447 "original dimensions to be tiled");
5448 SmallVector<int64_t> staticTileSizes;
5449 SmallVector<Value> dynamicTileSizes;
5450 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5451 build(builder, state, dest.getType(), source, dest,
5452 paddingValue ? *paddingValue : nullptr,
5453 outerDimsPerm.empty() ? nullptr
5454 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5455 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5456 builder.getDenseI64ArrayAttr(staticTileSizes));
5457}
5458
/// Reifies the result shape as dim-queries of the dest operand (shared
/// pack/unpack implementation).
5459LogicalResult
5460PackOp::reifyResultShapes(OpBuilder &builder,
5461                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5462  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5463}
5464
/// Maps each tiled dim position to its mixed tile size (shared impl).
5465DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5466  return getDimAndTileMappingImpl(*this);
5467}
5468
/// Returns static + dynamic inner tiles as one OpFoldResult list (shared
/// impl).
5469SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5470  return getMixedTilesImpl(*this);
5471}
5472
/// Returns the static view of the tile sizes, with kDynamic for dynamic
/// entries (shared impl).
5473SmallVector<int64_t> PackOp::getStaticTiles() {
5474  return getStaticTilesImpl(*this);
5475}
5476
5477ArrayRef<int64_t> PackOp::getAllOuterDims() {
5478 ShapedType inputType = getSourceType();
5479 int64_t inputRank = inputType.getRank();
5480 return getDestType().getShape().take_front(inputRank);
5481}
5482
/// Returns, for each tiled (inner) dim position, the size of its
/// corresponding outer dim, reported in source-dimension order.
5483SmallVector<int64_t> PackOp::getTiledOuterDims() {
5484  auto innerDimsPos = getInnerDimsPos();
5485  SmallVector<int64_t> outerDims(getAllOuterDims());
5486  SmallVector<int64_t> res;
5487
5488  // Recover the original order of the outer dims.
  // NOTE(review): if `invertPermutationVector` returns the inverse rather
  // than inverting in place, its result is discarded on the next line and
  // `outerDims` is permuted by the *forward* permutation — verify intent.
5489  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5490  invertPermutationVector(outerDimPermInv);
5491  if (!outerDimPermInv.empty())
5492    applyPermutationToVector(outerDims, outerDimPermInv);
5493
5494  // Collect the outer dims corresponding to the tiled inner dims.
5495  for (auto index : innerDimsPos)
5496    res.push_back(outerDims[index]);
5497
5498  return res;
5499}
5500
/// Returns true if packing provably needs a padding value: some static input
/// dim is not evenly divided by its (static) tile size or by the static
/// outer size when the tile itself is dynamic. Dynamic input dims are
/// assumed to divide evenly (partial tiles there are UB, see verify()).
5501bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5502                                 ArrayRef<int64_t> innerDimsPos,
5503                                 ArrayRef<int64_t> outputShape,
5504                                 ArrayRef<int64_t> outerDimsPerm,
5505                                 ArrayRef<OpFoldResult> innerTiles) {
  // Outer sizes of the packed shape, permuted back to source-dim order.
5506  SmallVector<int64_t> outputTileSizes(
5507      outputShape.take_front(inputShape.size()));
5508  if (!outerDimsPerm.empty()) {
5509    assert(outerDimsPerm.size() == outputTileSizes.size() &&
5510           "expected output and outer_dims_perm to have same size");
5511    applyPermutationToVector(outputTileSizes,
5512                             invertPermutationVector(outerDimsPerm));
5513  }
5514  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    // Dynamic input dims cannot be checked statically; skip them.
5515    if (ShapedType::isDynamic(inputShape[pos]))
5516      continue;
5517    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5518    if (!constantTile) {
      // Dynamic tile: fall back to checking against the static outer size.
5519      if (ShapedType::isStatic(outputTileSizes[pos]) &&
5520          (inputShape[pos] % outputTileSizes[pos] != 0))
5521        return true;
5522    } else {
5523      assert(*constantTile != 0 && "static tile size can't be zero");
5524      if (inputShape[pos] % (*constantTile) != 0) {
5525        return true;
5526      }
5527    }
5528  }
5529  return false;
5530}
5531
5532bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5533 ArrayRef<int64_t> innerDimsPos,
5534 ArrayRef<int64_t> outputShape,
5535 ArrayRef<int64_t> outerDimsPerm,
5536 ArrayRef<OpFoldResult> innerTiles) {
5537 SmallVector<int64_t> outputTileSizes(
5538 outputShape.take_front(inputShape.size()));
5539 if (!outerDimsPerm.empty()) {
5540 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5541 "expected output and outer_dims_perm to have same size");
5542 applyPermutationToVector(outputTileSizes,
5543 invertPermutationVector(outerDimsPerm));
5544 }
5545 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5546 if (ShapedType::isDynamic(inputShape[pos]) ||
5547 ShapedType::isDynamic(outputTileSizes[pos]))
5548 return true;
5549 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5550 if (!constantTile)
5551 return true;
5552 assert(*constantTile != 0 && "static tile size can't be zero");
5553 if (inputShape[pos] % (*constantTile) != 0)
5554 return true;
5555 }
5556 return false;
5557}
5558
/// Op verifier: runs the shared pack/unpack structural checks, then the
/// pack-specific padding-value checks.
5559LogicalResult PackOp::verify() {
5561    return failure();
5562
5563  // Verify padding value, and bail out if the tile does not divide the
5564  // dimension fully. In the case of dynamic tile factors or dimensions, having
5565  // a partial tile is undefined behavior.
5566  auto paddingValue = getPaddingValue();
  // The padding element must match the source element type exactly.
5567  if (paddingValue &&
5568      paddingValue.getType() != getSourceType().getElementType()) {
5569    return emitOpError("expected padding_value has ")
5570           << getSourceType().getElementType()
5571           << " but got: " << paddingValue.getType();
5572  }
5573
  // Partial tiles without a padding value are rejected statically.
5574  if (!paddingValue &&
5575      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5576                          getDestType().getShape(), getOuterDimsPerm(),
5577                          getMixedTiles())) {
5578    return emitOpError(
5579        "invalid tile factor or output size provided. Only full tiles are "
5580        "supported when padding_value is not set");
5581  }
5582  return success();
5583}
5584
5585/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5586/// Value's to kDynamic, even if they are arith.constant values.
5587static SmallVector<int64_t>
5590  for (auto o : ofrs) {
5591    // Have to do this first, as getConstantIntValue special-cases constants.
5592    if (llvm::dyn_cast_if_present<Value>(o))
5593      result.push_back(ShapedType::kDynamic);
5594    else
      // Attribute case: extract the constant, or kDynamic if unreadable.
5595      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5596  }
5597  return result;
5598}
5599
5600SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5601 ArrayRef<int64_t> innerTileSizes,
5602 ArrayRef<int64_t> innerDimsPos,
5603 ArrayRef<int64_t> outerDimsPerm) {
5604 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5605 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5606 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5607 continue;
5608 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5609 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5610 continue;
5611 }
5612 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5613 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5614 }
5615
5616 // Swap tile loops if outer_dims_perm is available.
5617 if (!outerDimsPerm.empty())
5618 applyPermutationToVector(resultShape, outerDimsPerm);
5619
5620 // Append the inner tile dimensions.
5621 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5622 return resultShape;
5623}
5624
/// Computes the packed result shape as OpFoldResults, folding ceil-div
/// expressions where possible, and normalizes each entry to a Value exactly
/// when the inferred static shape says that dim is dynamic.
5625SmallVector<OpFoldResult> PackOp::getResultShape(
5626    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5627    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5628    ArrayRef<int64_t> outerDimsPerm) {
5629  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5630
  // dim.ceilDiv(tile) as a folded affine apply for every tiled position.
5631  AffineExpr s0, s1;
5632  bindSymbols(builder.getContext(), s0, s1);
5633  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5634  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5635    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5636        builder, loc, ceilDivExpr,
5637        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5638  }
5639  if (!outerDimsPerm.empty())
5640    applyPermutationToVector(resultDims, outerDimsPerm);
5641  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5642
  // Static view of the same shape (Values treated as dynamic).
5643  SmallVector<int64_t> resultTypeShape =
5644      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
5645                       asShapeWithAnyValueAsDynamic(innerTileSizes),
5646                       innerDimsPos, outerDimsPerm);
5647
5648  // Fix-up `resultDims` to ensure that they are Value's if and only if the
5649  // result type shape says it's a dynamic dim. This is needed as callers may
5650  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5651  // dynamic dims returned by that.
5652  for (unsigned i = 0; i < resultDims.size(); ++i) {
5653    if (ShapedType::isStatic(resultTypeShape[i]))
5654      continue;
5655    resultDims[i] =
5656        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5657  }
5658
5659  return resultDims;
5660}
5661
5662RankedTensorType PackOp::inferPackedTensorType(
5663 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5664 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5665 SmallVector<int64_t> resultShape = inferPackedShape(
5666 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5667 return RankedTensorType::get(resultShape, sourceType.getElementType());
5668}
5669
5670MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5671 ArrayRef<int64_t> innerTileSizes,
5672 ArrayRef<int64_t> innerDimsPos,
5673 ArrayRef<int64_t> outerDimsPerm) {
5674 SmallVector<int64_t> resultShape = inferPackedShape(
5675 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5676 return MemRefType::get(resultShape, sourceType.getElementType());
5677}
5678
5679Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5680 ArrayRef<OpFoldResult> innerTileSizes,
5681 ArrayRef<int64_t> innerDimsPos,
5682 ArrayRef<int64_t> outerDimsPerm) {
5683 AffineExpr dim0, dim1;
5684 bindDims(b.getContext(), dim0, dim1);
5685 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5686 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5687 {v1, v2});
5688 };
5689
5690 SmallVector<OpFoldResult> mixedSizes;
5691 for (auto [index, value] : llvm::enumerate(
5692 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5693 if (ShapedType::isDynamic(value))
5694 mixedSizes.push_back(
5695 tensor::DimOp::create(b, loc, source, index).getResult());
5696 else
5697 mixedSizes.push_back(b.getIndexAttr(value));
5698 }
5699 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5700 int64_t dimPos = std::get<0>(it);
5701 OpFoldResult tileSize = std::get<1>(it);
5702 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5703 }
5704 if (!outerDimsPerm.empty())
5705 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5706
5707 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5708 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5709 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5710}
5711
/// Creates a clone of this pack op with the inner dims/tiles permuted by
/// `innerPermutation` and the outer dims by `outerPermutation`, including a
/// freshly computed destination tensor for the transposed layout.
5712PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5713                                     ArrayRef<int64_t> innerPermutation,
5714                                     ArrayRef<int64_t> outerPermutation) {
5715  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5716      *this, innerPermutation, outerPermutation);
  // The dest must match the permuted metadata, so rebuild it from scratch.
5717  Value transposedDest =
5718      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5719                              metadata.innerDimsPos, metadata.outerDimsPerm);
5720  return PackOp::create(b, loc, getSource(), transposedDest,
5721                        metadata.innerDimsPos, metadata.innerTiles,
5722                        getPaddingValue(), metadata.outerDimsPerm);
5723}
5724
/// Shared memory-effect computation for pack/unpack: tensor semantics means
/// no effects; with buffers, the source is read and the dest is read and
/// written.
5725template <typename OpTy>
5728        &effects) {
5729  // No memory effects for pure tensor semantics
5730  if (op.hasPureTensorSemantics())
5731    return;
5732
5733  for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5734    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
5735      continue;
5736
5737    if (&opOperand == &op.getSourceMutable()) {
5738      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5739                           /*effectOnFullRegion=*/true,
5741    } else if (&opOperand == &op.getDestMutable()) {
      // Dest is both read (pre-existing contents) and written.
5742      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5743                           /*effectOnFullRegion=*/true,
5745      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
5746                           /*effectOnFullRegion=*/true,
5748    }
5749  }
5750}
5751
/// Memory effects: delegated to the shared pack/unpack implementation.
5752void PackOp::getEffects(
5754        &effects) {
5755  getPackUnPackEffectsImpl(*this, effects);
5756}
5757
/// Memory effects: delegated to the shared pack/unpack implementation.
5758void UnPackOp::getEffects(
5760        &effects) {
5761  getPackUnPackEffectsImpl(*this, effects);
5762}
5763
5764/// Returns true if the tiles and the tiled dims are constant.
5765template <typename OpTy>
5767  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5768                "applies to only pack or unpack operations");
  // The packed side carries the tile dims: dest for pack, source for unpack.
5769  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5770                              ? op.getDestType()
5771                              : op.getSourceType();
5772  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  // Every trailing tile dim and its tile size must both be static.
5773  for (auto [dimDest, tile] : llvm::zip(
5774           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5775    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5776    if (!constTileSize || ShapedType::isDynamic(dimDest))
5777      return false;
5778  }
5779  return true;
5780}
5781
/// Speculatability: buffer semantics and padded packs are handled first;
/// otherwise the op is speculatable only with fully constant tiles/dims.
5782Speculation::Speculatability PackOp::getSpeculatability() {
5783  if (!hasPureTensorSemantics())
5785  if (getPaddingValue())
5787
5788  // The verifier rejects already operations if we can statically prove that the
5789  // sizes of the tiles do not divide perfectly the dimension; thus, check only
5790  // to have constant tiles and tiled inner dimensions.
5793
5795}
5796
5797// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5798// dimensions for pack and unpack.
5799static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5800 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5801 return false;
5802 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5803 return true;
5804 // Outer dims permutation is optional.
5805 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5806 // identity permutation.
5807 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5808 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5809}
5810
5811// Return true if pack and unpack have the same tiles.
5812// Same SSA values or same integer constants.
5813static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5814 auto packTiles = packOp.getMixedTiles();
5815 auto unPackTiles = unPackOp.getMixedTiles();
5816 if (packTiles.size() != unPackTiles.size())
5817 return false;
5818 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5819 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5820 return false;
5821 }
5822 return true;
5823}
5824
5825/// Returns true if the pack op does not need a padding value.
5826static bool paddingIsNotNeeded(PackOp op) {
5827 auto srcType = op.getSourceType();
5828 if (llvm::any_of(op.getInnerDimsPos(),
5829 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5830 return false;
5831 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5832 return false;
5833 return !PackOp::requirePaddingValue(
5834 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5835 op.getOuterDimsPerm(), op.getMixedTiles());
5836}
5837
5838/// Returns true if the `srcShape` or `destShape` is different from the one in
5839/// `packOp` and populates each with the inferred static shape.
5840static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5841                             SmallVectorImpl<int64_t> &destShape) {
5842  bool changeNeeded = false;
5843  srcShape.assign(packOp.getSourceType().getShape().begin(),
5844                  packOp.getSourceType().getShape().end());
5845  destShape.assign(packOp.getDestType().getShape().begin(),
5846                   packOp.getDestType().getShape().end());
  // Tiled dims change size between src and dest; only untiled outer dims can
  // have their static sizes propagated across the op.
5847  llvm::SmallSetVector<int64_t, 4> innerDims;
5848  innerDims.insert_range(packOp.getInnerDimsPos());
5849  SmallVector<int64_t> inverseOuterDimsPerm;
5850  if (!packOp.getOuterDimsPerm().empty())
5851    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5852  int srcRank = packOp.getSourceRank();
5853  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5854    if (innerDims.contains(i))
5855      continue;
5856    int64_t srcPos = i;
5857    int64_t destPos = i;
    // Map the source dim to its (possibly permuted) dest position.
5858    if (!inverseOuterDimsPerm.empty())
5859      destPos = inverseOuterDimsPerm[srcPos];
    // Nothing to learn if both sides are static or both dynamic.
5860    if (ShapedType::isDynamic(srcShape[srcPos]) ==
5861        ShapedType::isDynamic(destShape[destPos])) {
5862      continue;
5863    }
    // Copy the static size over to the dynamic side.
5864    int64_t size = srcShape[srcPos];
5865    if (ShapedType::isDynamic(size))
5866      size = destShape[destPos];
5867    srcShape[srcPos] = size;
5868    destShape[destPos] = size;
5869    changeNeeded = true;
5870  }
5871  return changeNeeded;
5872}
5873
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // Fold a pack(unpack(x)) to x when the round-trip is exact: same type, no
  // padding, same inner/outer dim attributes, and same tiles.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() == packOp.getDestType() &&
        !packOp.getPaddingValue() &&
        hasSameInnerOuterAttribute(packOp, unPackOp) &&
        haveSameTiles(packOp, unPackOp)) {
      rewriter.replaceOp(packOp, unPackOp.getSource());
      return success();
    }
  }

  // Fold optional PaddingValue operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    // Cast the source to its refined (more static) type if it changed.
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    ShapedType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
    }
    // Rewire the op in place to the (possibly casted) operands and refine its
    // result type to match the new dest.
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Cast the refined result back to the original type so existing users
    // remain type-correct.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
                                           packOp.getResult());
      rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
    }
    return success();
  }

  return failure();
}
5933
5934template <typename PackOrUnpackOp>
5935static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
5936 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5937 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5938 "Function meant for pack/unpack");
5939 // This is a pad if packing only adds ones and we don't transpose dimensions.
5940
5941 // Check that we are not transposing any dimensions.
5942 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5943 int64_t numPackedDims = innerDimsPos.size();
5944 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5945 if (orderedDims != innerDimsPos) {
5946 // Dimensions don't happen in order.
5947 return false;
5948 }
5949
5950 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5951 int64_t packedRank = packedTensorType.getRank();
5952 // At this point we know that we are taking numPackedDims outer
5953 // dimensions and pushing them all the way as the inner most dimensions.
5954 // What's left on the outer most dimensions is, in this order:
5955 // - the factor of the packed dimensions, then
5956 // - the untouched dimensions
5957 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5958 // if all the dimensions that bubble outerward are ones.
5959 // Therefore check that all the dimensions but the numPackedDims inner most
5960 // ones are ones.
5961 return llvm::all_of(
5962 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5963 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5964}
5965
5966bool PackOp::isLikePad() {
5967 auto packedTensorType =
5968 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5969 return isLikePadUnPad(*this, packedTensorType);
5970}
5971
5972::mlir::LogicalResult
5973PackOp::fold(FoldAdaptor adaptor,
5975 if (!hasPureTensorSemantics())
5976 return failure();
5977 std::optional<Attribute> paddingValue;
5978 if (auto pad = adaptor.getPaddingValue())
5979 paddingValue = pad;
5980 if (OpFoldResult reshapedSource = reshapeConstantSource(
5981 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5982 cast<TensorType>(getDestType()), paddingValue)) {
5983 results.push_back(reshapedSource);
5984 return success();
5985 }
5986 return failure();
5987}
5988
5989/// Folds a tensor.cast op into a consuming PackOp op if the
5990/// `tensor.cast` has source that is more static than the consuming op.
5991///
5992/// Example:
5993/// ```mlir
5994/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5995/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5996/// ```
5997///
5998/// folds into:
5999///
6000/// ```mlir
6001/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
6002/// ```
6005
6006 LogicalResult matchAndRewrite(PackOp op,
6007 PatternRewriter &rewriter) const override {
6008 // TODO: Support Memref PackOp. Temporarily return failure.
6009 if (!op.hasPureTensorSemantics())
6010 return failure();
6011
6013 return failure();
6014
6015 SmallVector<Type> newResultTypes(op->getResultTypes());
6016 SmallVector<Value> newOperands =
6018
6019 // Get the updated mixed-tile-sizes attribute.
6020 SmallVector<OpFoldResult> newMixedTileSizes =
6021 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
6022 if (llvm::any_of(newMixedTileSizes, isZeroInteger))
6023 return failure();
6024
6025 // Clone op.
6026 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6027 // this point. However, in practice, we use them for things that we'd like
6028 // to preserve. Implement a better abstraction.
6029 PackOp newOp =
6030 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6031 op.getInnerDimsPos(), newMixedTileSizes,
6032 op.getPaddingValue(), op.getOuterDimsPerm());
6033 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6034
6035 // Replace op.
6036 Value oldResult = op.getResult();
6037 Value newResult = newOp.getResult();
6039 (newResult.getType() != oldResult.getType())
6040 ? tensor::CastOp::create(rewriter, op->getLoc(),
6041 oldResult.getType(), newResult)
6042 : newResult;
6043
6044 rewriter.replaceOp(op, {replacement});
6045
6046 return success();
6047 }
6048};
6049
6050//===----------------------------------------------------------------------===//
6051// UnPackOp
6052//===----------------------------------------------------------------------===//
6053
6054void UnPackOp::getAsmResultNames(
6055 function_ref<void(Value, StringRef)> setNameFn) {
6056 if (!getResults().empty())
6057 setNameFn(getResult(), "unpack");
6058}
6059
6060// Custom parser for UnPackOp that handles the memref/tensor case distinction
6061ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6062 OpAsmParser::UnresolvedOperand source, dest;
6064 SmallVector<int64_t> staticTiles;
6065 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6066 Type sourceType, destType, resultType;
6067
6068 if (parser.parseOperand(source))
6069 return failure();
6070
6071 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6072 if (parser.parseEqual())
6073 return failure();
6074
6075 SmallVector<int64_t> outerDimsPermVec;
6077 int64_t value;
6078 if (parser.parseInteger(value))
6079 return failure();
6080 outerDimsPermVec.push_back(value);
6081 return success();
6082 }))
6083 return failure();
6084 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6085 }
6086
6087 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6088 return failure();
6089
6090 SmallVector<int64_t> innerDimsPosVec;
6092 int64_t value;
6093 if (parser.parseInteger(value))
6094 return failure();
6095 innerDimsPosVec.push_back(value);
6096 return success();
6097 }))
6098 return failure();
6099 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6100
6101 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6102 return failure();
6103
6104 DenseI64ArrayAttr staticTilesAttr;
6105 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6106 return failure();
6107 for (auto val : staticTilesAttr.asArrayRef())
6108 staticTiles.push_back(val);
6109
6110 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6111 return failure();
6112
6113 if (parser.parseOptionalAttrDict(result.attributes))
6114 return failure();
6115
6116 if (parser.parseColon() || parser.parseType(sourceType))
6117 return failure();
6118
6119 bool hasArrow = succeeded(parser.parseOptionalArrow());
6120 if (hasArrow) {
6121 if (parser.parseType(destType))
6122 return failure();
6123 }
6124
6125 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6126 if (!hasArrow) {
6127 return parser.emitError(parser.getCurrentLocation(),
6128 "pack/unpack requires '->' and destination type");
6129 }
6130
6131 if (!isMemRef)
6132 resultType = destType;
6133
6134 if (parser.resolveOperand(source, sourceType, result.operands) ||
6135 parser.resolveOperand(dest, destType, result.operands))
6136 return failure();
6137
6138 if (!dynamicTiles.empty() &&
6139 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6140 result.operands))
6141 return failure();
6142
6143 result.addAttribute("static_inner_tiles",
6144 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6145 result.addAttribute("inner_dims_pos", innerDimsPos);
6146 if (outerDimsPerm)
6147 result.addAttribute("outer_dims_perm", outerDimsPerm);
6148
6149 SmallVector<int32_t> segmentSizes = {
6150 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6151 result.addAttribute("operandSegmentSizes",
6152 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6153
6154 if (!isMemRef)
6155 result.addTypes(resultType);
6156
6157 return success();
6158}
6159
6160void UnPackOp::print(OpAsmPrinter &p) {
6161 p << " " << getSource();
6162
6163 if (!getOuterDimsPerm().empty()) {
6164 p << " outer_dims_perm = [";
6165 llvm::interleaveComma(getOuterDimsPerm(), p);
6166 p << "]";
6167 }
6168
6169 p << " inner_dims_pos = [";
6170 llvm::interleaveComma(getInnerDimsPos(), p);
6171 p << "]";
6172
6173 p << " inner_tiles = ";
6174 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6175
6176 p << " into " << getDest();
6177
6178 p.printOptionalAttrDict((*this)->getAttrs(),
6179 {"static_inner_tiles", "inner_dims_pos",
6180 "outer_dims_perm", "operandSegmentSizes"});
6181
6182 p << " : " << getSource().getType();
6183 p << " -> " << getDest().getType();
6184}
6185
/// Reifies the result shape via the implementation shared with PackOp.
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
6191
/// Maps tiled dim position -> tile size, via the shared pack/unpack helper.
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}
6195
/// Returns the tile sizes as a mix of static attrs and dynamic SSA values.
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}
6199
/// Returns the static tile sizes, via the shared pack/unpack helper.
SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}
6203
6204ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6205 ShapedType destType = getDestType();
6206 int64_t destRank = destType.getRank();
6207 return getSourceType().getShape().take_front(destRank);
6208}
6209
6210SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6211 auto innerDimsPos = getInnerDimsPos();
6212 SmallVector<int64_t> outerDims(getAllOuterDims());
6213 SmallVector<int64_t> res;
6214
6215 // Recover the original order of the outer dims.
6216 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6217 invertPermutationVector(outerDimPermInv);
6218 if (!outerDimPermInv.empty())
6219 applyPermutationToVector(outerDims, outerDimPermInv);
6220
6221 // Collect the outer dims corresponding to the tilled inner dims.
6222 for (auto index : innerDimsPos)
6223 res.push_back(outerDims[index]);
6224
6225 return res;
6226}
6227
/// UnPackOp shares its structural verification with PackOp.
LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}
6231
6232Speculation::Speculatability UnPackOp::getSpeculatability() {
6233 if (!hasPureTensorSemantics())
6235 // See PackOp::getSpeculatability.
6238
6240}
6241
/// Convenience builder taking mixed (static/dynamic) tile sizes; the result
/// type is taken from `dest`.
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  // Split the mixed tile sizes into static values and dynamic SSA values, as
  // expected by the generated builder.
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
6258
6259Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6260 Value source,
6261 ArrayRef<OpFoldResult> innerTileSizes,
6262 ArrayRef<int64_t> innerDimsPos,
6263 ArrayRef<int64_t> outerDimsPerm) {
6264 AffineExpr sym0, sym1;
6265 bindSymbols(b.getContext(), sym0, sym1);
6266 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6267 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6268 };
6269
6270 SmallVector<OpFoldResult> mixedSizes;
6271 auto srcType = llvm::cast<RankedTensorType>(source.getType());
6272 for (auto i :
6273 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6274 if (srcType.isDynamicDim(i))
6275 mixedSizes.push_back(
6276 tensor::DimOp::create(b, loc, source, i).getResult());
6277 else
6278 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6279 }
6280 if (!outerDimsPerm.empty()) {
6282 mixedSizes, invertPermutationVector(outerDimsPerm));
6283 }
6284
6285 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6286 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6287
6288 auto elemType = srcType.getElementType();
6289 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6290}
6291
/// Creates a clone of this unpack with `innerPermutation`/`outerPermutation`
/// applied to its metadata, reading from the already-transposed
/// `transposedSource`. The destination operand is reused unchanged.
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return UnPackOp::create(b, loc, transposedSource, getDest(),
                          metadata.innerDimsPos, metadata.innerTiles,
                          metadata.outerDimsPerm);
}
6302
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `op` and populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  // Start from the shapes currently on the op; refine them below.
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  // Note: unlike the PackOp overload, iteration is over the *dest* (unpacked)
  // rank, and the permutation maps dest positions to source positions.
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    // Only propagate along dims that are not tiled; tiled dims do not map
    // 1:1 between source and dest.
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    // Nothing to infer when both sides are static or both are dynamic.
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    // Copy the static size onto the dynamic side.
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
6338
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // TODO: Support Memref UnPackOp. Temporarily return failure.
  if (!unPackOp.hasPureTensorSemantics())
    return failure();

  /// unpack(pack(x)) -> x, when the round-trip is exact: matching type, no
  /// padding, same inner/outer dim attributes, and same tiles.
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x): use the DPS producer's init
  /// directly as the destination, bypassing the producer.
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      // Slice the destination instead, and shrink the unpack's result type to
      // the sliced type; the slice's users then use the unpack directly.
      auto newDest = tensor::ExtractSliceOp::create(
          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    // Cast operands to their refined (more static) types where they changed.
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = tensor::CastOp::create(rewriter, loc, newSrcType,
                                      unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                    unPackOp.getDest());
    }
    // Recreate the unpack with the refined operands and cast the result back
    // to the original type so existing users remain type-correct.
    UnPackOp newOp = UnPackOp::create(
        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp.getResult());
    return success();
  }

  return failure();
}
6411
6412bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6413 // Rank-reduced folding is not supported.
6414 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6415 return false;
6416 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6417 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6418 return false;
6419 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6420 SmallVector<int64_t> outerShapeWithoutTranspose =
6422 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
6423 for (auto [pos, tileSize] :
6424 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6425 areOuterDimsTiled[pos] = true;
6426 if (unpackedTypeAfterFold.isDynamicDim(pos))
6427 return false;
6428 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6429 return false;
6430 if (ShapedType::isDynamic(tileSize))
6431 return false;
6432 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6433 unpackedTypeAfterFold.getDimSize(pos);
6434 if (paddingSize >= tileSize)
6435 return false;
6436 }
6437 // extract_slice must not affect dimensions that are not being unpacked
6438 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6439 if (areOuterDimsTiled[pos])
6440 continue;
6441 int64_t dim = outerShapeWithoutTranspose[pos];
6442 if (ShapedType::isDynamic(dim))
6443 return false;
6444 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6445 return false;
6446 }
6447 return true;
6448}
6449
6450bool UnPackOp::isLikeUnPad() {
6451 ShapedType packedTensorType = getSourceType();
6452 return isLikePadUnPad(*this, packedTensorType);
6453}
6454
6455::mlir::LogicalResult
6456UnPackOp::fold(FoldAdaptor adaptor,
6457 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6458 // TODO: Support Memref UnPackOp. Temporarily return failure.
6459 if (!hasPureTensorSemantics())
6460 return failure();
6461
6462 if (OpFoldResult reshapedSource = reshapeConstantSource(
6463 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6464 cast<TensorType>(getResult().getType()))) {
6465 results.push_back(reshapedSource);
6466 return success();
6467 }
6468 return failure();
6469}
6470
6471/// Folds a tensor.cast op into a consuming UnPackOp op if the
6472/// `tensor.cast` has source that is more static than the consuming op.
6473///
6474/// Example:
6475/// ```mlir
6476/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6477/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6478/// ```
6479///
6480/// folds into:
6481///
6482/// ```mlir
6483/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6484/// ```
6485struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6486 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6487
6488 LogicalResult matchAndRewrite(UnPackOp op,
6489 PatternRewriter &rewriter) const override {
6490 // TODO: Support Memref UnPackOp. Temporarily return failure.
6491 if (!op.hasPureTensorSemantics())
6492 return failure();
6493
6495 return failure();
6496
6497 SmallVector<Type> newResultTypes(op->getResultTypes());
6498 SmallVector<Value> newOperands =
6500 Value sourceTensor = newOperands[0];
6501
6502 // Get the updated mixed-tile-sizes attribute.
6504 rewriter, sourceTensor.getType(), op.getMixedTiles());
6505
6506 // Clone op.
6507 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6508 // this point. However, in practice, we use them for things that we'd like
6509 // to preserve. Implement a better abstraction.
6510 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6511 newOperands[1], op.getInnerDimsPos(),
6512 newMixedTileSizes, op.getOuterDimsPerm());
6513 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6514
6515 // Replace op.
6516 Value oldResult = op.getResult();
6517 Value newResult = newOp.getResult();
6519 (newResult.getType() != oldResult.getType())
6520 ? tensor::CastOp::create(rewriter, op->getLoc(),
6521 oldResult.getType(), newResult)
6522 : newResult;
6523
6524 rewriter.replaceOp(op, {replacement});
6525
6526 return success();
6527 }
6528};
6529
6530//===----------------------------------------------------------------------===//
6531// BatchReduceMatmulOp
6532//===----------------------------------------------------------------------===//
6533SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6535 utils::IteratorType::reduction, utils::IteratorType::parallel,
6536 utils::IteratorType::parallel, utils::IteratorType::reduction};
6537}
6538
6539SmallVector<AffineMap>
6540BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6541 AffineExpr d0, d1, d2, d3;
6542 SmallVector<AffineMap> indexingMaps;
6543 bindDims(context, d0, d1, d2, d3);
6544 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6545 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6546 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6547 return indexingMaps;
6548}
6549
6550bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6551 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6552 if (!maps)
6553 return false;
6554 if (maps.size() != 3)
6555 return false;
6556 auto positions = getAffineResultPositions(maps);
6557 if (failed(positions))
6558 return false;
6559 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6560 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6561 (*positions)[2] == SmallVector<int64_t>{1, 2};
6562}
// Region takes three block arguments: the two input elements and the
// accumulator (see regionBuilder below).
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6564
/// Returns the mangled external library call name for this op.
std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
6568
6569/// Check if the op has broadcast and/or transpose semantic. Returns true if
6570/// the user defined indexing maps are not equal to default map.
6571bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6572 SmallVector<AffineMap, 3> defaultMaps =
6573 getDefaultIndexingMaps(this->getContext());
6574 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6575 return defaultMaps != explicitMaps;
6576}
6577
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
/// broadcast map must include K dimension.
/// TODO: Strict inclusion of K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to have K dimension for one input matrix map and infer the K
/// dimension for other input matrix map from the one already having K
/// dimension.
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  // Iteration-space dims: 0=batch, 1=m, 2=n, 3=k.
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    // Rank-1 operand: its single dim must be the reduction (K) dim.
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(0);
    AffineExpr expr1 = bcastMap.getResult(1);
    // Rank-2 operand. LHS: (batch|m, k). RHS: (batch, k) or (k, n).
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
                  expr0.isFunctionOfDim(mPos)) &&
                 expr1.isFunctionOfDim(kPos))
              : ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
  }
  return isValid;
}
6607
6608void BatchReduceMatmulOp::regionBuilder(
6609 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6610 function_ref<InFlightDiagnostic()> emitError) {
6611 if (emitError && block.getNumArguments() != 3) {
6612 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6613 << block.getNumArguments();
6614 return;
6615 }
6616 assert(block.getNumArguments() == 3 &&
6617 "BatchReduceMatmulOp regionBuilder expects 3 args");
6618 RegionBuilderHelper helper(b, block);
6619 SmallVector<Value> yields;
6620
6621 auto toType = block.getArgument(2).getType();
6622 Value castValA =
6623 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6624 Value castValB =
6625 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6626 Value mulVal =
6627 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
6628 if (!castValA || !castValB || !mulVal)
6629 return;
6630 Value addVal =
6631 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6632 if (!addVal)
6633 return;
6634 yields.push_back(addVal);
6635 helper.yieldOutputs(yields);
6636}
6637
/// Custom parser: an optional `indexing_maps = [...]` clause followed by the
/// generic named-structured-op syntax.
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();

    // Parse a comma-separated list of affine map attributes.
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  // Delegate the rest (operands, types, region) to the shared parser.
  return ::parseNamedStructuredOp(parser, result,
                                  BatchReduceMatmulOp::getNumRegionArgs(),
                                  BatchReduceMatmulOp::getRegionBuilder());
}
6676
/// Custom printer mirroring BatchReduceMatmulOp::parse: prints
/// `indexing_maps` only when they deviate from the defaults.
void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });

  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  // These attributes are printed above or handled by the generic syntax.
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                           elidedAttrs);
}
6694
6695/// Verify the user defined indexing maps.
6696LogicalResult BatchReduceMatmulOp::verify() {
6697 // Verification of pure batch_reduce_matmul is handled by
6698 // verifyStructuredOpInterface().
6699 if (!hasUserDefinedMaps())
6700 return success();
6701
6702 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6704 return failure();
6705 }
6706 return success();
6707}
/// Folds memref.cast producers into this op's operands where legal.
LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
6712void BatchReduceMatmulOp::getEffects(
6713 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6714 &effects) {
6715 if (hasPureTensorSemantics())
6716 return;
6717 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6718}
6719
/// Delegates to the generic structured-op speculatability logic.
Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
6723
6724} // namespace linalg
6725} // namespace mlir
6726
6727//===----------------------------------------------------------------------===//
6728// LinalgDialect
6729//===----------------------------------------------------------------------===//
6730
/// Registers dialect-wide canonicalization patterns that are not attached to
/// a single op.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
              FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
}
6736
/// Materializes a constant of `type` from `value` using arith.constant.
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantics specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:59
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
result_iterator result_begin()
Definition Operation.h:439
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_type_range getOperandTypes()
Definition Operation.h:423
result_iterator result_end()
Definition Operation.h:440
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:241
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. size(exprs)).
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
Definition Utils.cpp:1330
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override