MLIR 23.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
39
40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/SmallVectorExtras.h"
45#include "llvm/ADT/StringSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/FormatVariadic.h"
48#include "llvm/Support/InterleavedRange.h"
49#include "llvm/Support/LogicalResult.h"
50#include "llvm/Support/MathExtras.h"
51#include "llvm/Support/raw_ostream.h"
52#include <cassert>
53#include <optional>
54
55using namespace mlir;
56using namespace mlir::linalg;
57
58/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
// NOTE(review): the declaration line (orig. 59) and the TypeSwitch header
// (orig. 66) were lost in extraction; only the body survives. Static dims
// fold to an index attribute, dynamic dims materialize a dim op whose
// dialect (tensor vs. memref) matches the shaped type of `v`.
 60 int64_t dim) {
 61 auto type = cast<ShapedType>(v.getType());
 62 if (!type.isDynamicDim(dim))
 63 return builder.getIndexAttr(type.getDimSize(dim));
 64
 65 return getAsOpFoldResult(
 67 .Case([&](RankedTensorType t) -> Value {
 68 return tensor::DimOp::create(builder, loc, v, dim);
 69 })
 70 .Case([&](MemRefType t) -> Value {
 71 return memref::DimOp::create(builder, loc, v, dim);
 72 }));
 73}
74
75/// Returns a memref.subview or a tensor.extract_slice based on the type of the
76/// `source`.
// NOTE(review): the declaration and TypeSwitch header lines (orig. 77-79, 81)
// were lost in extraction. Per the .Default case, types that are neither
// RankedTensorType nor MemRefType yield nullptr.
 80 ArrayRef<OpFoldResult> strides) {
 82 .Case([&](RankedTensorType t) -> Operation * {
 83 return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
 84 strides);
 85 })
 86 .Case([&](MemRefType type) -> Operation * {
 87 return memref::SubViewOp::create(b, loc, source, offsets, sizes,
 88 strides);
 89 })
 90 .Default([&](Type t) -> Operation * { return nullptr; });
 91}
92
93//===----------------------------------------------------------------------===//
94// Helper functions
95//===----------------------------------------------------------------------===//
96
// Create-or-fold a dim op for `source` at `dim`, dispatching on whether the
// value is memref- or tensor-typed; any other shaped type is a programmer
// error (llvm_unreachable). NOTE(review): the declaring line (orig. 97) was
// lost in extraction — presumably this is linalg::createOrFoldDimOp, matching
// the call at orig. line 110 below; confirm against upstream.
 98 int64_t dim) {
 99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
 100 return b.createOrFold<memref::DimOp>(loc, source, dim);
 101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
 102 return b.createOrFold<tensor::DimOp>(loc, source, dim);
 103 llvm_unreachable("Expected MemRefType or TensorType");
 104}
105
// Fold the dim query to a constant index attribute when the dimension is
// statically known; otherwise (unranked or dynamic dim) fall back to
// createOrFoldDimOp. NOTE(review): the declaring line (orig. 106) was lost in
// extraction.
 107 int64_t dim) {
 108 auto shapedType = llvm::cast<ShapedType>(source.getType());
 109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
 110 return createOrFoldDimOp(b, loc, source, dim);
 111 return b.getIndexAttr(shapedType.getDimSize(dim));
 112}
113
114//===----------------------------------------------------------------------===//
115// Support for named Linalg ops defined in ods-gen.
116//===----------------------------------------------------------------------===//
117
121
122/// Fills the region of a structured operation using the provided
123/// `regionBuilder`. The method is used by both named structured ops created by
124/// ods-gen and by manually defined C++ ops. It is called by both builders and
125/// parsers and creates a block with arguments corresponding to the elemental
126/// types of `inputTypes` and `outputTypes`.
// NOTE(review): two parameter lines (orig. 129-130 — they declared at least
// `attrs` and `emitError`, used at orig. 151) and the `argLocs` vector
// declaration (orig. 133) were lost in extraction.
127static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 128 TypeRange inputTypes, TypeRange outputTypes,
 131 RegionBuilderFn regionBuilder) {
 132 SmallVector<Type, 8> argTypes;
 // Block arguments are the *elemental* types of the inputs/outputs: shaped
 // operands (memref/tensor) contribute their element type, everything else
 // is passed through unchanged.
 134 for (auto containers : {inputTypes, outputTypes}) {
 135 for (auto t : containers) {
 136 argTypes.push_back(
 137 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
 138
 139 // TODO: Pass in a proper location here.
 140 argLocs.push_back(opBuilder.getUnknownLoc());
 141 }
 142 }
 143
 144 // RAII.
 145 OpBuilder::InsertionGuard guard(opBuilder);
 146 Block *body =
 147 opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
 148
 149 opBuilder.setInsertionPointToStart(body);
 150 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
 151 regionBuilder(b, *body, attrs, emitError);
 152
 153 // indexing_maps is an auto-generated method.
 154
 155 // iterator_types is an auto-generated method.
 156}
157
158/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
159/// The result types are derived automatically if `resultTensorTypes` is none.
160/// The body of the operation is filled using `regionBuilder`. All ods-gen
161/// created structured operations use the method to implement their builders.
// NOTE(review): the declaring line (orig. 162) was lost in extraction; it
// introduced the `b` (builder) and `state` (OperationState) parameters used
// throughout the body.
 163 std::optional<TypeRange> resultTensorTypes,
 164 ValueRange inputs, ValueRange outputs,
 165 ArrayRef<NamedAttribute> attributes,
 166 RegionBuilderFn regionBuilder) {
 167 // Derive the result types if needed.
 // When the caller passes no result types, every ranked-tensor output
 // becomes a result (memref outputs produce no results).
 168 SmallVector<Type> derivedResultTypes =
 169 resultTensorTypes.value_or(TypeRange());
 170 if (!resultTensorTypes)
 171 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
 172 llvm::IsaPred<RankedTensorType>);
 173
 174 state.addOperands(inputs);
 175 state.addOperands(outputs);
 176 state.addTypes(derivedResultTypes);
 177
 178 state.addAttributes(attributes);
 // Record the input/output split so the op can recover its two operand
 // segments from the flat operand list.
 179 state.addAttribute(
 180 "operandSegmentSizes",
 181 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
 182 static_cast<int32_t>(outputs.size())}));
 183
 184 // Create and fill the region of the structured operation.
 185 Region &region = *state.addRegion();
 186 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
 187 state.attributes.getAttrs(), /*emitError=*/{},
 188 regionBuilder);
 189}
190
// Build a matmul-style structured op, installing `defaultIndexingMaps` as the
// "indexing_maps" attribute unless the caller already supplied one.
// NOTE(review): the declaring line (orig. 191) was lost in extraction.
 192 std::optional<TypeRange> resultTensorTypes,
 193 ValueRange inputs, ValueRange outputs,
 194 ArrayRef<NamedAttribute> attributes,
 195 RegionBuilderFn regionBuilder,
 196 ArrayRef<AffineMap> defaultIndexingMaps) {
 197 // If indexing maps are not provided, apply the default ones.
 198 if (none_of(attributes, [](NamedAttribute attr) {
 199 return attr.getName() == "indexing_maps";
 200 })) {
 201 SmallVector<Attribute, 3> indexingMapsAttrVal;
 202 indexingMapsAttrVal = llvm::map_to_vector(
 203 defaultIndexingMaps,
 204 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 205 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
 206 }
 207 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
 208 attributes, regionBuilder);
 209}
210
// Same shape as the builder above (defaults "indexing_maps" when absent);
// only the inline-vector capacity differs (4 vs. 3 maps). NOTE(review): the
// declaring line (orig. 211) was lost in extraction.
 212 std::optional<TypeRange> resultTensorTypes,
 213 ValueRange inputs, ValueRange outputs,
 214 ArrayRef<NamedAttribute> attributes,
 215 RegionBuilderFn regionBuilder,
 216 ArrayRef<AffineMap> defaultIndexingMaps) {
 217 // If indexing maps are not provided, apply the default ones.
 218 if (none_of(attributes, [](NamedAttribute attr) {
 219 return attr.getName() == "indexing_maps";
 220 })) {
 221 SmallVector<Attribute, 4> indexingMapsAttrVal;
 222 indexingMapsAttrVal = llvm::map_to_vector(
 223 defaultIndexingMaps,
 224 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 225 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
 226 }
 227 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
 228 attributes, regionBuilder);
 229}
230
// Unconditionally installs `indexingMaps` as the "indexing_maps" attribute
// (unlike the builders above there is no caller-provided-attribute check).
// NOTE(review): the declaring line (orig. 231) was lost in extraction; the
// comment below identifies this as the BatchReduceMatmulOp builder.
 232 std::optional<TypeRange> resultTensorTypes,
 233 ValueRange inputs, ValueRange outputs,
 234 ArrayRef<NamedAttribute> attributes,
 235 RegionBuilderFn regionBuilder,
 236 ArrayRef<AffineMap> indexingMaps) {
 237 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
 238 SmallVector<Attribute, 4> indexingMapsAttrVal;
 239 indexingMapsAttrVal =
 240 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
 241 return AffineMapAttr::get(map);
 242 });
 243 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
 244 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
 245 attributes, regionBuilder);
 246}
247
248/// Common parsing used for both named structured ops created by ods-gen and by
249/// manually defined C++ ops. Does not handle regions.
250static ParseResult
// NOTE(review): extraction lost the first declarator line (orig. 251 —
// parser/result parameters), the operand-vector declaration (orig. 256,
// pairing with `outputsOperands` on the next line), and the two
// getDenseI32ArrayAttr( call heads (orig. 300 and 306).
 252 SmallVectorImpl<Type> &inputTypes,
 253 SmallVectorImpl<Type> &outputTypes,
 254 bool addOperandSegmentSizes = true) {
 255 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
 257 outputsOperands;
 258
 // Optional `<...>` properties dictionary preceding the attr-dict.
 259 if (succeeded(parser.parseOptionalLess())) {
 260 if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
 261 return failure();
 262 }
 263 attrsLoc = parser.getCurrentLocation();
 264 if (parser.parseOptionalAttrDict(result.attributes))
 265 return failure();
 266
 // `ins(<operands> : <types>)` — optional input clause.
 267 if (succeeded(parser.parseOptionalKeyword("ins"))) {
 268 if (parser.parseLParen())
 269 return failure();
 270
 271 inputsOperandsLoc = parser.getCurrentLocation();
 272 if (parser.parseOperandList(inputsOperands) ||
 273 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
 274 return failure();
 275 }
 276
 // `outs(<operands> : <types>)` — optional output clause.
 277 if (succeeded(parser.parseOptionalKeyword("outs"))) {
 278 outputsOperandsLoc = parser.getCurrentLocation();
 279 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
 280 parser.parseColonTypeList(outputTypes) || parser.parseRParen())
 281 return failure();
 282 }
 283
 284 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
 285 result.operands) ||
 286 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
 287 result.operands))
 288 return failure();
 289
 290 if (addOperandSegmentSizes) {
 291 // This is a bit complex because we're trying to be backward compatible with
 292 // operation syntax that mix the inherent attributes and the discardable
 293 // ones in the same dictionary. If the properties are used, we append the
 294 // operandSegmentSizes there directly. Otherwise we append it to the
 295 // discardable attributes dictionary where it is handled by the generic
 296 // Operation::create(...) method.
 297 if (result.propertiesAttr) {
 298 NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
 299 attrs.append("operandSegmentSizes",
 301 {static_cast<int32_t>(inputsOperands.size()),
 302 static_cast<int32_t>(outputsOperands.size())}));
 303 result.propertiesAttr = attrs.getDictionary(parser.getContext());
 304 } else {
 305 result.addAttribute("operandSegmentSizes",
 307 {static_cast<int32_t>(inputsOperands.size()),
 308 static_cast<int32_t>(outputsOperands.size())}));
 309 }
 310 }
 // Without properties, inherent attributes end up in the discardable dict;
 // verify them eagerly so errors point at the attribute location.
 311 if (!result.propertiesAttr) {
 312 std::optional<RegisteredOperationName> info =
 313 result.name.getRegisteredInfo();
 314 if (info) {
 315 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
 316 return parser.emitError(attrsLoc)
 317 << "'" << result.name.getStringRef() << "' op ";
 318 })))
 319 return failure();
 320 }
 321 }
 322 return success();
 323}
324
// Print the shared `ins(...)`/`outs(...)` clauses; empty ranges are elided.
// NOTE(review): the declaring line (orig. 325) was lost in extraction.
 326 ValueRange outputs) {
 327 if (!inputs.empty())
 328 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
 329 if (!outputs.empty())
 330 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 331}
332
333//===----------------------------------------------------------------------===//
334// Specific parsing and printing for named structured ops created by ods-gen.
335//===----------------------------------------------------------------------===//
336
// NOTE(review): extraction lost the function-name line (orig. 337 — the
// formatv text below names it parseNamedStructuredOpRegion) and the
// fillStructuredOpRegion( call head (orig. 351).
 338 OpAsmParser &parser, Region &region, unsigned numRegionArgs,
 339 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
 340 RegionBuilderFn regionBuilder, SMLoc loc) {
 341 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
 342 return parser.emitError(
 343 parser.getCurrentLocation(),
 344 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
 345 "region expects {0} args, got {1}",
 346 numRegionArgs, inputTypes.size() + outputTypes.size()));
 347 }
 348
 349 OpBuilder opBuilder(parser.getContext());
 // Region building reports failure through the emitError callback: it flips
 // `result` to failure() before returning the in-flight diagnostic.
 350 ParseResult result = success();
 352 opBuilder, region, inputTypes, outputTypes, attrs,
 353 [&]() {
 354 result = failure();
 355 return parser.emitError(loc);
 356 },
 357 regionBuilder);
 358 return result;
 359}
360
// Parse the optional `-> <types>` result list of a named structured op.
// NOTE(review): the function-name line (orig. 362) was lost in extraction.
361static ParseResult
 363 SmallVectorImpl<Type> &resultTypes) {
 364 if (parser.parseOptionalArrowTypeList(resultTypes))
 365 return failure();
 366 return success();
 367}
368
// Top-level parser for named structured ops: common ins/outs parts, optional
// attr-dict, result types, then the (re-generated) region. NOTE(review): one
// parameter line (orig. 370 — the OperationState `result`) was lost in
// extraction.
369static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
 371 unsigned numRegionArgs,
 372 RegionBuilderFn regionBuilder) {
 373 // TODO: Enable when ods-gen supports captures.
 374 SmallVector<Type, 1> inputTypes, outputTypes;
 375 SMLoc loc = parser.getCurrentLocation();
 376 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
 377 return failure();
 378
 379 // Parse optional attributes.
 380 if (parser.parseOptionalAttrDict(result.attributes))
 381 return failure();
 382
 383 // TODO: consider merging results parsing into region parsing.
 384 // Need to wait for declarative assembly resolution to decide.
 385 SmallVector<Type, 1> outputTensorsTypes;
 386 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
 387 return failure();
 388 result.addTypes(outputTensorsTypes);
 389
 // The region is not present in the assembly; it is rebuilt from the
 // ods-gen regionBuilder using the parsed operand/result types.
 390 std::unique_ptr<Region> region = std::make_unique<Region>();
 391 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
 392 outputTypes, result.attributes.getAttrs(),
 393 regionBuilder, loc))
 394 return failure();
 395 result.addRegion(std::move(region));
 396
 397 return success();
 398}
399
// Print the `-> <types>` result list; elided entirely when there are no
// results. NOTE(review): the declaring line (orig. 400) was lost in
// extraction.
 401 TypeRange resultTypes) {
 402 if (resultTypes.empty())
 403 return;
 404 p.printOptionalArrowTypeList(resultTypes);
 405}
406
// Printer counterpart of parseNamedStructuredOp: attr-dict, then the common
// ins/outs clauses; the region is regenerated on parse and thus elided.
// NOTE(review): the declaring line (orig. 407) and the results-printing call
// (orig. 417) were lost in extraction.
 408 ValueRange inputs, ValueRange outputs,
 409 ArrayRef<StringRef> elidedAttrs = {}) {
 410 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 411
 412 // Printing is shared with generic ops, except for the region and
 413 // attributes.
 414 printCommonStructuredOpParts(p, inputs, outputs);
 415
 416 // Results printing.
 418
 419 // Region is elided.
 420}
421
422//===----------------------------------------------------------------------===//
423// Region builder helper.
424// TODO: Move this to a utility library.
425// The public methods on this class are referenced directly from generated code.
426// Helper build the unary, binary, and type conversion functions defined by the
427// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
428// class.
429//
430// Implementations of the math functions must be polymorphic over numeric types,
431// internally performing necessary casts. If the function application makes no
432// sense, then the only recourse is to assert and return nullptr. This can be
433// extended later if it becomes possible to fail construction of the region. The
434// invariant should be enforced at a higher level.
435//
436// TODO: These helpers are currently type polymorphic over the class of integer
437// and floating point types, but they will not internally cast within bit
438// widths of a class (mixed precision such as i8->i32) or across classes
439// (i.e. mixed float and integer). Many such combinations are ambiguous or need
440// to be handled with care and work is being considered to extend the op
441// language to make such cases explicit. In the mean-time, violating this will
442// fail verification, which is deemed acceptable.
443//===----------------------------------------------------------------------===//
444
445namespace {
446
447class RegionBuilderHelper {
448public:
449 RegionBuilderHelper(OpBuilder &builder, Block &block)
450 : builder(builder), block(block) {}
451
452 // Build the unary functions defined by OpDSL.
453 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
454 function_ref<InFlightDiagnostic()> emitError = {}) {
455 if (!isFloatingPoint(arg)) {
456 if (emitError) {
457 emitError() << "unsupported non numeric type";
458 return nullptr;
459 }
460 llvm_unreachable("unsupported non numeric type");
461 }
462 OpBuilder::InsertionGuard g(builder);
463 builder.setInsertionPointToEnd(&block);
464 switch (unaryFn) {
465 case UnaryFn::exp:
466 return math::ExpOp::create(builder, arg.getLoc(), arg);
467 case UnaryFn::log:
468 return math::LogOp::create(builder, arg.getLoc(), arg);
469 case UnaryFn::abs:
470 return math::AbsFOp::create(builder, arg.getLoc(), arg);
471 case UnaryFn::ceil:
472 return math::CeilOp::create(builder, arg.getLoc(), arg);
473 case UnaryFn::floor:
474 return math::FloorOp::create(builder, arg.getLoc(), arg);
475 case UnaryFn::negf:
476 return arith::NegFOp::create(builder, arg.getLoc(), arg);
477 case UnaryFn::reciprocal: {
478 Attribute oneAttr = builder.getOneAttr(arg.getType());
479 auto one = arith::ConstantOp::create(builder, arg.getLoc(),
480 ::cast<TypedAttr>(oneAttr));
481 return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
482 }
483 case UnaryFn::round:
484 return math::RoundOp::create(builder, arg.getLoc(), arg);
485 case UnaryFn::sqrt:
486 return math::SqrtOp::create(builder, arg.getLoc(), arg);
487 case UnaryFn::rsqrt:
488 return math::RsqrtOp::create(builder, arg.getLoc(), arg);
489 case UnaryFn::square:
490 return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
491 case UnaryFn::tanh:
492 return math::TanhOp::create(builder, arg.getLoc(), arg);
493 case UnaryFn::erf:
494 return math::ErfOp::create(builder, arg.getLoc(), arg);
495 }
496 if (emitError) {
497 emitError() << "unsupported unary function";
498 return nullptr;
499 }
500 llvm_unreachable("unsupported unary function");
501 }
502
503 // Build the binary functions defined by OpDSL.
504 // If emitError is provided, an error will be emitted if the operation is not
505 // supported and a nullptr will be returned, otherwise an assertion will be
506 // raised.
507 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
508 function_ref<InFlightDiagnostic()> emitError = {}) {
509 bool allComplex = isComplex(arg0) && isComplex(arg1);
510 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
511 bool allInteger = isInteger(arg0) && isInteger(arg1);
512 bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
513 arg1.getType().getIntOrFloatBitWidth() == 1;
514 if (!allComplex && !allFloatingPoint && !allInteger) {
515 if (emitError) {
516 emitError()
517 << "Cannot build binary Linalg operation: expects allComplex, "
518 "allFloatingPoint, or allInteger, got "
519 << arg0.getType() << " and " << arg1.getType();
520 return nullptr;
521 }
522 llvm_unreachable("unsupported non numeric type");
523 }
524 OpBuilder::InsertionGuard g(builder);
525 builder.setInsertionPointToEnd(&block);
526 switch (binaryFn) {
527 case BinaryFn::add:
528 if (allComplex)
529 return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
530 if (allFloatingPoint)
531 return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
532 if (allBool)
533 return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
534 return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
535 case BinaryFn::sub:
536 if (allComplex)
537 return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
538 if (allFloatingPoint)
539 return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
540 if (allBool) {
541 if (emitError) {
542 emitError() << "unsupported operation: sub with bools";
543 return nullptr;
544 }
545 llvm_unreachable("unsupported operation: sub with bools");
546 }
547 return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
548 case BinaryFn::mul:
549 if (allComplex)
550 return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
551 if (allFloatingPoint)
552 return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
553 if (allBool)
554 return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
555 return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
556 case BinaryFn::div:
557 if (allComplex)
558 return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
559 if (allFloatingPoint)
560 return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
561 if (allBool) {
562 if (emitError) {
563 emitError() << "unsupported operation: div with bools";
564 return nullptr;
565 }
566 llvm_unreachable("unsupported operation: div with bools");
567 }
568 return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
569 case BinaryFn::div_unsigned:
570 if (!allInteger || allBool) {
571 if (emitError) {
572 emitError() << "unsupported operation: unsigned div not on uint";
573 return nullptr;
574 }
575 llvm_unreachable("unsupported operation: unsigned div not on uint");
576 }
577 return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
578 case BinaryFn::max_signed:
579 assert(!allComplex);
580 if (allFloatingPoint)
581 return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
582 return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
583 case BinaryFn::min_signed:
584 assert(!allComplex);
585 if (allFloatingPoint)
586 return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
587 return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
588 case BinaryFn::max_unsigned:
589 assert(!allComplex);
590 if (!allInteger || allBool) {
591 if (emitError) {
592 emitError() << "unsupported operation: unsigned max not on uint";
593 return nullptr;
594 }
595 llvm_unreachable("unsupported operation: unsigned max not on uint");
596 }
597 return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
598 case BinaryFn::min_unsigned:
599 assert(!allComplex);
600 if (!allInteger || allBool) {
601 if (emitError) {
602 emitError() << "unsupported operation: unsigned min not on uint";
603 return nullptr;
604 }
605 llvm_unreachable("unsupported operation: unsigned min not on uint");
606 }
607 return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
608 case BinaryFn::powf:
609 assert(allFloatingPoint);
610 return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
611 }
612 if (emitError) {
613 emitError() << "unsupported binary function";
614 return nullptr;
615 }
616 llvm_unreachable("unsupported binary function");
617 }
618
619 // Build the ternary functions defined by OpDSL.
620 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
621 function_ref<InFlightDiagnostic()> emitError = {}) {
622 OpBuilder::InsertionGuard g(builder);
623 builder.setInsertionPointToEnd(&block);
624 switch (ternaryFn) {
625 case TernaryFn::select:
626 return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
627 }
628 if (emitError) {
629 emitError() << "unsupported ternary function";
630 return nullptr;
631 }
632 llvm_unreachable("unsupported ternary function");
633 }
634
635 // Build the type functions defined by OpDSL.
636 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
637 function_ref<InFlightDiagnostic()> emitError = {}) {
638 switch (typeFn) {
639 case TypeFn::cast_signed:
640 return cast(toType, operand, false);
641 case TypeFn::cast_unsigned:
642 return cast(toType, operand, true);
643 }
644 if (emitError) {
645 emitError() << "unsupported type conversion function";
646 return nullptr;
647 }
648 llvm_unreachable("unsupported type conversion function");
649 }
650
651 void yieldOutputs(ValueRange values) {
652 OpBuilder::InsertionGuard g(builder);
653 builder.setInsertionPointToEnd(&block);
654 Location loc = builder.getUnknownLoc();
655 YieldOp::create(builder, loc, values);
656 }
657
658 Value constant(const std::string &value) {
659 OpBuilder::InsertionGuard g(builder);
660 builder.setInsertionPointToEnd(&block);
661 Location loc = builder.getUnknownLoc();
662 Attribute valueAttr = parseAttribute(value, builder.getContext());
663 return arith::ConstantOp::create(builder, loc,
664 ::cast<TypedAttr>(valueAttr));
665 }
666
667 Value index(int64_t dim) {
668 OpBuilder::InsertionGuard g(builder);
669 builder.setInsertionPointToEnd(&block);
670 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
671 }
672
673 Type getIntegerType(unsigned width) {
674 return IntegerType::get(builder.getContext(), width);
675 }
676
677 Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
678 Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
679
680private:
681 // Generates operations to cast the given operand to a specified type.
682 // If the cast cannot be performed, a warning will be issued and the
683 // operand returned as-is (which will presumably yield a verification
684 // issue downstream).
685 Value cast(Type toType, Value operand, bool isUnsignedCast) {
686 OpBuilder::InsertionGuard g(builder);
687 builder.setInsertionPointToEnd(&block);
688 auto loc = operand.getLoc();
689 if (isa<UnknownLoc>(loc)) {
690 if (operand.getDefiningOp())
691 loc = operand.getDefiningOp()->getLoc();
692 else if (operand.getParentBlock() &&
693 operand.getParentBlock()->getParentOp())
694 loc = operand.getParentBlock()->getParentOp()->getLoc();
695 }
696 return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
697 }
698
699 bool isComplex(Value value) {
700 return llvm::isa<ComplexType>(value.getType());
701 }
702 bool isFloatingPoint(Value value) {
703 return llvm::isa<FloatType>(value.getType());
704 }
705 bool isInteger(Value value) {
706 return llvm::isa<IntegerType>(value.getType());
707 }
708
709 OpBuilder &builder;
710 Block &block;
711};
712
713} // namespace
714
715//===----------------------------------------------------------------------===//
716// CopyOp
717//===----------------------------------------------------------------------===//
718
719namespace {
720
721struct EraseSelfCopy : OpRewritePattern<CopyOp> {
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter) const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
726 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
727 if (copyOp.hasPureBufferSemantics())
728 rewriter.eraseOp(copyOp);
729 else
730 rewriter.replaceOp(copyOp, copyOp.getInputs());
731
732 return success();
733 }
734};
735
736} // namespace
737
// Register CopyOp canonicalizations: currently only self-copy erasure.
738void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 739 MLIRContext *context) {
 740 results.add<EraseSelfCopy>(context);
 741}
742
743//===----------------------------------------------------------------------===//
744// FillOp
745//===----------------------------------------------------------------------===//
746
747namespace {
748
749/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
750///
751/// For such op chains, we can create new linalg.fill ops with the result
752/// type of the tensor.expand/collapse_shape op.
753template <typename TensorReshapeOp>
754struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter) const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
759 if (!oldFill)
760 return failure();
761
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
765
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
770 } else {
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
774 }
775 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
776 ValueRange{newInit});
777 return success();
778 }
779};
780
781/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
782/// filling value are the same.
// Fold tensor.pad(linalg.fill) into a fill of the padded shape when the pad's
// constant padding value equals the fill value. NOTE(review): the using-
// declaration line (orig. 784) was lost in extraction.
783struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
 785
 786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
 787 PatternRewriter &rewriter) const override {
 788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
 789 if (!fillOp)
 790 return failure();
 791
 792 // We can only fold if the padding value is the same as the original
 793 // filling value.
 794 Value padValue = padOp.getConstantPaddingValue();
 795 if (!padValue || fillOp.value() != padValue)
 796 return failure();
 797
 // The result shape of the pad is needed to size the new empty tensor.
 798 ReifiedRankedShapedTypeDims reifiedShape;
 799 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
 800 return rewriter.notifyMatchFailure(
 801 padOp, "failed to reify tensor.pad op result shape");
 802
 803 auto emptyTensor =
 804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
 805 padOp.getResultType().getElementType());
 806 Value replacement =
 807 FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
 808 ValueRange{emptyTensor})
 809 .getResult(0);
 // Insert a cast when the reified (possibly more dynamic) type differs
 // from the pad's declared result type.
 810 if (replacement.getType() != padOp.getResultType()) {
 811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
 812 padOp.getResultType(), replacement);
 813 }
 814 rewriter.replaceOp(padOp, replacement);
 815 return success();
 816 }
 817};
818
819/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
820/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
821/// filling value are the same.
// Fold tensor.insert_slice(tensor.pad(x), linalg.fill) into
// tensor.insert_slice(x, linalg.fill) when the pad and fill values agree: the
// padded area already holds the fill value, so only the unpadded source needs
// inserting (at offsets shifted by the low padding). NOTE(review): the using-
// declaration line (orig. 823) was lost in extraction.
822struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
 824
 825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
 826 PatternRewriter &rewriter) const override {
 827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
 828 if (!srcPadOp)
 829 return failure();
 830
 831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
 832 return failure();
 833
 834 // Walk back the tensor.insert_slice chain and find the first destination
 835 // value at the start of the chain.
 836 Value firstDest = insertOp.getDest();
 837 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
 838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
 839 return failure();
 840
 841 // Make sure the range of values accessed are disjoint. Without this, we
 842 // cannot fold tensor.pad away.
 843 bool disjoint = false;
 844 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
 845 // If the dimension has dynamic offset/size, we cannot guarantee
 846 // disjoint. So just skip it.
 847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
 848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
 849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
 850 continue;
 851
 852 // Get the range start and end, inclusively for both.
 853 int64_t prevStart = prevOp.getStaticOffset(i);
 854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
 855 prevOp.getStaticStride(i);
 856 int64_t nextStart = insertOp.getStaticOffset(i);
 857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
 858 insertOp.getStaticStride(i);
 859 if (prevEnd < nextStart || nextEnd < prevStart) {
 860 disjoint = true;
 861 break;
 862 }
 863 }
 864
 865 if (!disjoint)
 866 break;
 867 firstDest = prevOp.getDest();
 868 }
 869
 870 // Check whether the first destination is a fill op. For overlapped cases,
 871 // this also cannot be true.
 872 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
 873 if (!dstFillOp)
 874 return failure();
 875
 876 // We can only fold if the padding value is the same as the original
 877 // filling value.
 878 Value padValue = srcPadOp.getConstantPaddingValue();
 879 if (!padValue || dstFillOp.value() != padValue)
 880 return failure();
 881
 882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
 883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
 884
 885 Location loc = insertOp.getLoc();
 886 MLIRContext *context = getContext();
 887
 // Affine map (s0, s1) -> s0 + s1, used to fold offset + low-pad sums.
 888 AffineExpr sym0, sym1;
 889 bindSymbols(context, sym0, sym1);
 890 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
 891
 892 // Calculate the new offsets for the insert. It should be the old offsets
 893 // plus low padding sizes.
 894 SmallVector<OpFoldResult, 4> newOffsets;
 895 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
 896 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
 897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
 898 }
 899
 // New sizes are the (pre-pad) source's dims: static dims become index
 // attributes, dynamic dims become tensor.dim ops.
 900 RankedTensorType srcPadType = srcPadOp.getSourceType();
 901 SmallVector<OpFoldResult, 4> newSizes;
 902 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
 903 if (srcPadType.isDynamicDim(i)) {
 904 newSizes.push_back(
 905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
 906 .getResult());
 907 } else {
 908 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
 909 }
 910 }
 911
 912 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
 913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
 914 newSizes, insertOp.getMixedStrides());
 915 return success();
 916 }
 917};
918
919/// Fold tensor.extract(linalg.fill(<input>)) into <input>
920struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
921public:
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
923
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter) const override {
926 // See if tensor input of tensor.extract op is the result of a linalg.fill
927 // op.
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
929 if (!fillOp)
930 return failure();
931
932 // Get scalar input operand of linalg.fill op.
933 Value extractedScalar = fillOp.getInputs()[0];
934
935 // Replace tensor.extract op with scalar value used to fill the tensor.
936 rewriter.replaceOp(extractOp, extractedScalar);
937 return success();
938 }
939};
940
941/// Folds pack(fill) into a single fill op if
942/// 1. The pack op does not have padding value, or
943/// 2. The filled value and padding value are the same.
944static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
947 if (!fillOp)
948 return failure();
949
950 if (auto paddingValue = packOp.getPaddingValue())
951 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
952 return failure();
953
954 Value packOpDest = packOp.getDest();
955 if (!packOpDest.hasOneUse())
956 return failure();
957
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
959 packOp.getDest());
960}
961
962/// Wrapper pattern that applies foldFillPackIntoFillOp method.
963struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
964public:
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
967
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter) const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
971 if (failed(fillOp))
972 return failure();
973 rewriter.replaceOp(packOp, fillOp.value().result());
974 return success();
975 }
976};
977
978/// Fold fill with copy.
979struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
981
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter) const override {
984 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
985 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
986 fillOp.getInputs(),
987 copyOp.getOutputs());
988 return success();
989 }
990 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
991 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
992 fillOp.getOutputs());
993 return success();
994 }
995 return failure();
996 }
997};
998
999/// Fold fill with transpose.
1000struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1002
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter) const override {
1005 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1006 rewriter.replaceOpWithNewOp<FillOp>(
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1009 return success();
1010 }
1011 return failure();
1012 }
1013};
1014
1015/// Fold a concat with all elements being fills of the same value
1016/// into a fill of the concat result shape.
1017struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
1019
1020 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1021 PatternRewriter &rewriter) const override {
1022 auto concatOperands = concatOp.getInputs();
1023 if (concatOperands.empty()) {
1024 return failure();
1025 }
1026
1027 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1028 if (!firstFillOp) {
1029 return failure();
1030 }
1031 // Prefetch the fill value.
1032 OpFoldResult firstFillVal =
1033 getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
1034 // Collect all the outs values for the fill operations.
1035 SmallVector<Value> allOuts;
1036 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1037
1038 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
1039 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1040 if (!fillOp) {
1041 return false;
1042 }
1043
1044 OpFoldResult fillVal =
1045 getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
1046 if (fillVal != firstFillVal)
1047 return false;
1048
1049 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1050 return true;
1051 };
1052 if (!llvm::all_of(concatOperands.drop_front(),
1053 isDefinedByCompatibleFillOp)) {
1054 return rewriter.notifyMatchFailure(
1055 concatOp, "not all operands are defined by a compatible fill op");
1056 }
1057
1058 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1059 concatOp.getDim(), allOuts);
1060 rewriter.replaceOpWithNewOp<linalg::FillOp>(
1061 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
1062 return success();
1063 }
1064};
1065
1066} // namespace
1067
1068void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1069 MLIRContext *context) {
1070 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1071 FoldFillWithPack, FoldFillWithPad,
1072 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1073 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1074 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1075}
1076
1077//===----------------------------------------------------------------------===//
1078// GenericOp
1079//===----------------------------------------------------------------------===//
1080
1082 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1083 ValueRange outputs,
1084 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1085 SmallVector<Type, 4> blockArgTypes;
1086 SmallVector<Location, 4> blockArgLocs;
1087 for (ValueRange container : {inputs, outputs}) {
1088 for (Value v : container) {
1089 Type t = v.getType();
1090 blockArgTypes.push_back(
1091 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
1092 blockArgLocs.push_back(v.getLoc());
1093 }
1094 }
1095
1096 OpBuilder::InsertionGuard guard(builder);
1097 Block *bodyBlock =
1098 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1099 bodyBuild(builder, loc, bodyBlock->getArguments());
1100}
1101
1102void GenericOp::getAsmBlockArgumentNames(Region &region,
1103 OpAsmSetValueNameFn setNameFn) {
1104 for (Value v : getRegionInputArgs())
1105 setNameFn(v, "in");
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v, "out");
1108}
1109
/// Builder taking already-materialized attributes. Delegates to the
/// ODS-generated builder for the core attributes, attaches any extra
/// `attributes`, and materializes the body region via `bodyBuild`.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  // Only create the region when a body builder is supplied; callers may
  // populate it themselves otherwise.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
1123
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1126 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1129 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder, result, resultTensorTypes, inputs, outputs,
1132 builder.getAffineMapArrayAttr(indexingMaps),
1133 builder.getArrayAttr(llvm::map_to_vector(
1134 iteratorTypes,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1137 })),
1138 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1139 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1141}
1142
/// Convenience builder without result tensor types (pure buffer form or
/// types inferred elsewhere); forwards with an empty TypeRange.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
1153
/// Convenience builder without result types, doc string, or library call;
/// forwards with empty doc/libraryCall strings.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1164
/// Convenience builder with result tensor types but without doc string or
/// library call; forwards with empty doc/libraryCall strings.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1176
/// Prints the generic op custom assembly: the dictionary of core trait
/// attributes, the ins/outs operand lists, an optional `attrs =` dictionary
/// for non-trait attributes, the body region, and the result types.
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
          iteratorTypes, [&](utils::IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          });

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // `operandSegmentSizes` is reconstructed from the ins/outs lists at parse
  // time, so it is elided along with the trait attributes.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Only emit `attrs = {...}` when at least one attribute is not already
  // covered by the trait dictionary printed above.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1238
/// Parses the generic op custom assembly; mirrors GenericOp::print: the
/// trait-attribute dictionary, ins/outs operands, optional `attrs =`
/// dictionary, the region, and the result types.
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  // Each string entry must name a valid iterator type.
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1304
1307 &effects,
1308 LinalgOp linalgOp) {
1309 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1310 if (!llvm::isa<MemRefType>(operand.getType()))
1311 continue;
1312 effects.emplace_back(
1313 MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1314 /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1315 }
1316
1317 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1318 if (!llvm::isa<MemRefType>(operand.get().getType()))
1319 continue;
1320 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1321 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1322 /*effectOnFullRegion=*/true,
1324 }
1325 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1326 /*effectOnFullRegion=*/true,
1328 }
1329}
1330
1331void GenericOp::getEffects(
1332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1333 &effects) {
1334 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1335}
1336
1339 // Operands with value semantics are speculatable, while operands with memory
1340 // semantics are not.
1341 if (!linalgOp.hasPureTensorSemantics())
1343 // The body of the op can still have speculation in its region.
1345}
1346
1347Speculation::Speculatability GenericOp::getSpeculatability() {
1348 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1349}
1350
1351namespace {
1352
1353/// Remove linalg operations that are just copying the values from inputs to
1354/// results. In the memref case, the operation must be copying to and from the
1355/// same value. Requirements are:
1356/// 1) All iterator types are parallel
1357/// 2) The body contains just a yield operation with the yielded values being
1358/// the arguments corresponding to the operands.
1359template <typename OpTy>
1360struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1361 using OpRewritePattern<OpTy>::OpRewritePattern;
1362
1363 LogicalResult matchAndRewrite(OpTy linalgOp,
1364 PatternRewriter &rewriter) const override {
1365 // All indexing maps must be equal. It follows that they are permutations.
1366 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1367 return failure();
1368
1369 // Check that the body of the linalg operation is just a linalg.yield
1370 // operation.
1371 Block &body = linalgOp->getRegion(0).front();
1372 if (!llvm::hasSingleElement(body))
1373 return failure();
1374 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1375 if (!yieldOp)
1376 return failure();
1377
1378 // In the buffer case, we need to check exact buffer equality.
1379 if (linalgOp.hasPureBufferSemantics()) {
1380 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1381 linalgOp.getDpsInputOperand(0)->get() !=
1382 linalgOp.getDpsInitOperand(0)->get()) {
1383 return rewriter.notifyMatchFailure(
1384 linalgOp, "expected single input and output to be the same value");
1385 }
1386
1387 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1388 if (!yieldArg || yieldArg.getOwner() != &body) {
1389 return rewriter.notifyMatchFailure(linalgOp,
1390 "cannot fold fill-like op");
1391 }
1392
1393 rewriter.eraseOp(linalgOp);
1394 return success();
1395 }
1396
1397 if (!linalgOp.hasPureTensorSemantics()) {
1398 return rewriter.notifyMatchFailure(
1399 linalgOp, "mixed semantics is not supported yet");
1400 }
1401
1402 // Get the argument number of the returned values. That is the operand
1403 // number to use for replacing uses of this operation.
1404 SmallVector<Value> returnedArgs;
1405 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1406 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1407 if (!yieldArg || yieldArg.getOwner() != &body)
1408 return failure();
1409 unsigned argumentNumber = yieldArg.getArgNumber();
1410 Value returnedArg = linalgOp->getOperand(argumentNumber);
1411 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1412 // The input can have a different type than the result, e.g. a dynamic
1413 // input dimension can be turned into a static output dimension.
1414 Type returnType = returnedArg.getType();
1415 if (returnType != resultType) {
1416 // Distinguish between sparse conversion or dense tensor casting.
1417 // TODO: unify the two ops?
1420 returnedArg = sparse_tensor::ConvertOp::create(
1421 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1422 else {
1423 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1424 resultType))
1425 return failure();
1426 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1427 resultType, returnedArg);
1428 }
1429 }
1430 returnedArgs.push_back(returnedArg);
1431 }
1432
1433 if (returnedArgs.size() != linalgOp->getNumResults())
1434 return failure();
1435 rewriter.replaceOp(linalgOp, returnedArgs);
1436 return success();
1437 }
1438};
1439
1440} // namespace
1441
/// The only generic-specific canonicalization is erasing identity copies.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1446
/// Folding only absorbs producing memref.cast ops into the generic's
/// operands; no other folding is performed.
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
1450
1451//===----------------------------------------------------------------------===//
1452// MapOp
1453//===----------------------------------------------------------------------===//
1454
1455static ParseResult parseDstStyleOp(
1457 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1458 nullptr) {
1459 // Parse `ins` and `outs`.
1460 SmallVector<Type, 4> inputTypes, outputTypes;
1461 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1462 /*addOperandSegmentSizes=*/false))
1463 return failure();
1464
1465 // Add result types.
1466 for (Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1469 }
1470
1471 // Parse required attributes.
1472 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1473 return failure();
1474
1475 // Parse optional attributes.
1476 if (parser.parseOptionalAttrDict(result.attributes))
1477 return failure();
1478 return success();
1479}
1480
1481void MapOp::getAsmBlockArgumentNames(Region &region,
1482 OpAsmSetValueNameFn setNameFn) {
1483 for (Value v : getRegionInputArgs())
1484 setNameFn(v, "in");
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v, "init");
1487}
1488
1489void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(), "mapped");
1492}
1493
/// Builds a linalg.map over `inputs` with a single `init`. A tensor init
/// produces a result; a memref init does not. `bodyBuild`, when provided,
/// populates the mapper region.
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{init}, bodyBuild);
}
1510
1512 const OperationName &payloadOpName,
1513 const NamedAttrList &payloadOpAttrs,
1514 ArrayRef<Value> operands,
1515 bool initFirst = false, bool mapInit = true) {
1516 OpBuilder b(parser.getContext());
1517 Region *body = result.addRegion();
1518 Block &block = body->emplaceBlock();
1519 b.setInsertionPointToStart(&block);
1520 for (auto &operand : operands) {
1521 block.addArgument(
1522 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1523 b.getUnknownLoc());
1524 }
1525 SmallVector<Value> payloadOpOperands;
1526 // If initFirst flag is enabled, we consider init as the first position of
1527 // payload operands.
1528 if (initFirst) {
1529 if (mapInit)
1530 payloadOpOperands.push_back(block.getArguments().back());
1531 for (const auto &arg : block.getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1533 } else {
1534 payloadOpOperands = {block.getArguments().begin(),
1535 block.getArguments().end() - int(!mapInit)};
1536 }
1537
1538 Operation *payloadOp = b.create(
1539 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1540 payloadOpOperands,
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1542 .getElementType()},
1543 payloadOpAttrs);
1544 YieldOp::create(b, result.location, payloadOp->getResults());
1545}
1546
/// Parses linalg.map in either the short form (`{ payload.op } ins(...)
/// outs(...)`) — in which case the body is synthesized — or the long form
/// with an explicit region.
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short-form payload: `{ op.name {attrs} }`.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the region around the named payload op. With no
    // operands there is nothing to map, so only an empty region is added.
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs, ArrayRef(result.operands), false,
                           false);
    else
      result.addRegion();
  } else {
    // Long form: parse the explicit mapper region with its arguments.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1583
/// Returns true when `body` can be printed in the short `{ payload.op }`
/// form, i.e. it is exactly one payload op whose operands are the block
/// arguments (in the order dictated by `initFirst`/`mapInit`) and whose
/// single result is yielded.
static bool canUseShortForm(Block *body, bool initFirst = false,
                            bool mapInit = true) {
  // `initFirst == true` implies that we want to map init arg
  if (initFirst && !mapInit)
    return false;
  // Check if the body can be printed in short form. The following 4 conditions
  // must be satisfied:

  // 1) The body must contain exactly 2 operations: the payload op and a yield.
  if (body->getOperations().size() != 2)
    return false;
  Operation &payload = body->getOperations().front();

  // 2) The payload op must have the same number of operands as the number of
  //    block arguments.
  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
    return false;

  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
  //    must be the first operand of the payload op, otherwise, the operands
  //    must match the block arguments in order.
  if (initFirst) {
    // check init
    if (payload.getOperands().back() != body->getArgument(0))
      return false;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return false;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(),
                   body->getArguments().drop_back(int(!mapInit)))) {
      if (bbArg != operand)
        return false;
    }
  }

  // 4) The `yield` operand must be the result of the payload op.
  auto yieldOp = cast<YieldOp>(body->getTerminator());
  return yieldOp.getNumOperands() == 1 &&
         yieldOp.getOperand(0).getDefiningOp() &&
         yieldOp.getOperand(0).getDefiningOp() == &payload;
}
1631
/// Prints `payloadOp` in the short form `{ op.name {attrs} }`, eliding a
/// fastmath attribute when it carries the default `none` value.
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    // A `none` fastmath attribute is the default and adds no information;
    // drop it from the printed form.
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}
1648
/// Prints linalg.map, using the short `{ payload.op }` form when the body
/// qualifies (see canUseShortForm), otherwise the explicit region form.
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  bool useShortForm =
      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
  if (useShortForm) {
    printShortForm(p, &mapper->getOperations().front());
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!useShortForm) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1673
/// Verifies linalg.map invariants: the mapper arity equals inputs + init,
/// each mapper block argument has the element type of the corresponding
/// input, and every input shape equals the init shape.
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` + `init` match the arity of the `mapper`
  // region.
  if (getInputs().size() + 1 != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() + 1 << " and "
                         << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1710
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1714}
1715
1716ArrayAttr MapOp::getIndexingMaps() {
1717 Builder builder(getContext());
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1720 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1721 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1722}
1723
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1726 &effects) {
1727 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1728}
1729
1730Speculation::Speculatability MapOp::getSpeculatability() {
1731 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1732}
1733
1734//===----------------------------------------------------------------------===//
1735// ReduceOp
1736//===----------------------------------------------------------------------===//
1737
1738void ReduceOp::getAsmBlockArgumentNames(Region &region,
1739 OpAsmSetValueNameFn setNameFn) {
1740 for (Value v : getRegionInputArgs())
1741 setNameFn(v, "in");
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v, "init");
1744}
1745
1746void ReduceOp::getAsmResultNames(
1747 function_ref<void(Value, StringRef)> setNameFn) {
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(), "reduced");
1750}
1751
/// Builds a linalg.reduce over `dimensions`. Tensor inits produce results;
/// memref inits do not. `bodyBuild`, when provided, populates the combiner
/// region.
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
1771
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1773 int64_t inputRank =
1774 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1780}
1781
ArrayAttr ReduceOp::getIndexingMaps() {
  // Inputs are indexed with identity maps over the input rank; each init is
  // indexed with the identity map minus the reduced dimensions.
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
  // NOTE(review): the identity-map expressions (presumably
  // `AffineMap::getMultiDimIdentityMap(inputRank, getContext())`) appear to
  // be missing from this view on the two lines above/below — verify against
  // upstream.
  AffineMap resultMap =
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1795
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1798 &effects) {
1799 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1800}
1801
1802Speculation::Speculatability ReduceOp::getSpeculatability() {
1803 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1804}
1805
1806static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1807 NamedAttrList &attributes,
1808 StringRef attributeName) {
1809 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1810 return failure();
1811
1812 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1813 return success();
1814}
1815
1816ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1817 std::optional<OperationName> payloadOpName;
1818 NamedAttrList payloadOpAttrs;
1819 if (succeeded(parser.parseOptionalLBrace())) {
1820 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1821 if (failed(operationName))
1822 return failure();
1823 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1824 return failure();
1825 payloadOpName = operationName.value();
1826 if (parser.parseRBrace())
1827 return failure();
1828 }
1829
1830 if (parseDstStyleOp(
1831 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1832 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1833 }))
1834 return failure();
1835
1836 if (payloadOpName.has_value()) {
1837 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1838 ArrayRef(result.operands), /*initFirst=*/true);
1839 } else {
1840 SmallVector<OpAsmParser::Argument> regionArgs;
1841 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1842 /*allowType=*/true, /*allowAttrs=*/true)) {
1843 return failure();
1844 }
1845
1846 Region *body = result.addRegion();
1847 if (parser.parseRegion(*body, regionArgs))
1848 return failure();
1849 }
1850
1851 return success();
1852}
1853
1854static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1855 ArrayRef<int64_t> attributeValue) {
1856 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1857}
1858
1859void ReduceOp::print(OpAsmPrinter &p) {
1860 Block *mapper = getBody();
1861 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1862 if (useShortForm) {
1863 printShortForm(p, &mapper->getOperations().front());
1864 }
1865
1866 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1867 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1868 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1869 if (!useShortForm) {
1870 // Print region if the payload op was not detected.
1871 p.increaseIndent();
1872 p.printNewline();
1873 p << "(";
1874 llvm::interleaveComma(mapper->getArguments(), p,
1875 [&](auto arg) { p.printRegionArgument(arg); });
1876 p << ") ";
1877
1878 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1879 p.decreaseIndent();
1880 }
1881}
1882
/// Verifies operand counts, matching input/output shapes, the validity of
/// the reduction dimensions, and that the combiner block's argument types
/// match the operands' element types.
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // The ReduceOp uses `SameVariadicOperandSize`, which requires equal numbers
  // of inputs and inits. Detect a mismatch early: when they differ, the
  // ODS-generated getInputs()/getInits() accessors compute each group's size
  // via floordiv of the total operand count, producing incorrect slices that
  // would cause out-of-bounds accesses below.
  if (getInputs().size() != static_cast<size_t>(getNumDpsInputs()))
    return emitOpError()
           << "expected equal number of inputs and outputs (required by "
              "SameVariadicOperandSize), got "
           << getNumDpsInputs() << " input(s) and " << getNumDpsInits()
           << " output(s)";

  if (getInputs().empty())
    return emitOpError() << "expected at least one input";
  if (getInits().empty())
    return emitOpError() << "expected at least one output";

  // All inputs must share the shape of input 0.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  // All outputs must share the shape of output 0.
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Every reduction dimension must name a valid input dimension.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  // The init must have exactly the surviving dimensions, in order.
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  // The combiner block takes one argument per operand.
  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1985
1986//===----------------------------------------------------------------------===//
1987// TransposeOp
1988//===----------------------------------------------------------------------===//
1989
1990static void buildIdentityRegion(OpBuilder &builder, Location loc,
1991 Region &region, ValueRange inputs,
1992 ValueRange outputs) {
1993 buildGenericRegion(builder, loc, region, inputs, outputs,
1994 [](OpBuilder &b, Location loc, ValueRange args) {
1995 if (!args.empty())
1996 linalg::YieldOp::create(b, loc, args[0]);
1997 });
1998}
1999
2000void TransposeOp::build(::mlir::OpBuilder &builder,
2001 ::mlir::OperationState &result, Value input, Value init,
2002 DenseI64ArrayAttr permutation,
2003 ArrayRef<NamedAttribute> attributes) {
2004 result.addOperands(input);
2005 result.addOperands(init);
2006 result.addAttribute(getPermutationAttrName(result.name), permutation);
2007 result.addAttributes(attributes);
2008
2009 // Add output types for `RankedTensorType` output arguments.
2010 Type initType = init.getType();
2011 if (llvm::isa<RankedTensorType>(initType))
2012 result.addTypes(initType);
2013
2014 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2015 init);
2016}
2017
2018void TransposeOp::build(::mlir::OpBuilder &builder,
2019 ::mlir::OperationState &result, Value input, Value init,
2020 ArrayRef<int64_t> permutation,
2021 ArrayRef<NamedAttribute> attributes) {
2022 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2023 attributes);
2024}
2025
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the destination-style parts plus the `permutation` attribute.
  // NOTE(review): the opening `if (failed(parseDstStyleOp(` line appears to
  // be missing from this view — verify against upstream.
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  // The region is not spelled in the assembly; rebuild the canonical
  // identity body from the parsed operands.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
2039
2040void TransposeOp::getAsmResultNames(
2041 function_ref<void(Value, StringRef)> setNameFn) {
2042 if (!getResults().empty())
2043 setNameFn(getResults().front(), "transposed");
2044}
2045
2046void TransposeOp::print(OpAsmPrinter &p) {
2047 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2048 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2049 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2050}
2051
2052LogicalResult TransposeOp::verify() {
2053 ArrayRef<int64_t> permutationRef = getPermutation();
2054
2055 if (!isPermutationVector(permutationRef))
2056 return emitOpError("permutation is not valid");
2057
2058 auto inputType = getInput().getType();
2059 auto initType = getInit().getType();
2060
2061 int64_t rank = inputType.getRank();
2062
2063 if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
2064 "init")))
2065 return failure();
2066
2067 if (rank != static_cast<int64_t>(permutationRef.size()))
2068 return emitOpError() << "size of permutation " << permutationRef.size()
2069 << " does not match the argument rank " << rank;
2070
2071 auto inputDims = inputType.getShape();
2072 auto initDims = initType.getShape();
2073
2074 for (int64_t i = 0; i < rank; ++i) {
2075 int64_t inputDim = inputDims[permutationRef[i]];
2076 int64_t initDim = initDims[i];
2077
2078 if (inputDim != initDim) {
2079 return emitOpError() << "dim(result, " << i << ") = " << initDim
2080 << " doesn't match dim(input, permutation[" << i
2081 << "]) = " << inputDim;
2082 }
2083 }
2084
2085 return success();
2086}
2087
2088SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2089 int64_t rank = getInit().getType().getRank();
2090 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2091}
2092
ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  // Input map: permutation map built from the attribute; init map: identity.
  // NOTE(review): the `{AffineMap::getPermutationMap(` line appears to be
  // missing from this view — verify against upstream.
  return builder.getAffineMapArrayAttr(
          llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}
2101
2102void TransposeOp::getEffects(
2103 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2104 &effects) {
2105 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2106}
2107
2108Speculation::Speculatability TransposeOp::getSpeculatability() {
2109 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2110}
2111
2112LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2113 SmallVectorImpl<OpFoldResult> &result) {
2114 // Only the tensor type is supported.
2115 if (!isa<TensorType>(getInput().getType()))
2116 return failure();
2117
2118 // Single dimension transpose.
2119 if (getPermutation().empty()) {
2120 result.push_back(getInput());
2121 return success();
2122 }
2123 // Identity permutation.
2124 if (isIdentityPermutation(getPermutation())) {
2125 result.push_back(getInput());
2126 return success();
2127 }
2128
2129 return failure();
2130}
2131
2132/// Fold transpose with transpose.
2133struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2134 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2135
2136 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2137 PatternRewriter &rewriter) const override {
2138 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2139 if (!defTransposeOp)
2140 return failure();
2141 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2142 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2143 SmallVector<int64_t> foldedPerms;
2144 foldedPerms.reserve(perms.size());
2145 for (int64_t perm : perms)
2146 foldedPerms.push_back(defPerms[perm]);
2147
2148 rewriter.replaceOpWithNewOp<TransposeOp>(
2149 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2150 foldedPerms);
2151 return success();
2152 }
2153};
2154
/// This pattern canonicalize transpose by swapping the order of
/// broadcast and transpose:
/// transpose(broadcast(input)) -> broadcast(transpose(input))
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    // Only fire when the broadcast result feeds this transpose exclusively.
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
    // NOTE(review): a declaration of `invertPerm` (presumably
    // `invertPermutationVector(perms)`) appears to be missing from this
    // view — verify against upstream.
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    // NOTE(review): a declaration of `dims` (presumably
    // `SmallVector<OpFoldResult> dims;`) appears to be missing from this
    // view — verify against upstream.
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    // Collect the pre-broadcast shape as OpFoldResults (attrs for static
    // extents, tensor.dim values for dynamic ones).
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
                           ->getResult(0));
      } else {
        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
                                        broadcastInputTy.getDimSize(i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(dims, resultPerms);
    Value transposeInit = tensor::EmptyOp::create(
        rewriter, transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
    Value transposeResult =
        TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
                            transposeInit, resultPerms)
            ->getResult(0);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
    return success();
  }
};
2212
2213void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2214 MLIRContext *context) {
2215 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2216}
2217
2218//===----------------------------------------------------------------------===//
2219// BroadcastOp
2220//===----------------------------------------------------------------------===//
2221
2222void BroadcastOp::build(::mlir::OpBuilder &builder,
2223 ::mlir::OperationState &result, Value input, Value init,
2224 DenseI64ArrayAttr dimensions,
2225 ArrayRef<NamedAttribute> attributes) {
2226 result.addOperands(input);
2227 result.addOperands(init);
2228 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2229 result.addAttributes(attributes);
2230
2231 // Add output types for `RankedTensorType` output arguments.
2232 Type initType = init.getType();
2233 if (llvm::isa<RankedTensorType>(initType))
2234 result.addTypes(initType);
2235
2236 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2237 init);
2238}
2239
2240void BroadcastOp::build(::mlir::OpBuilder &builder,
2241 ::mlir::OperationState &result, Value input, Value init,
2242 ArrayRef<int64_t> dimensions,
2243 ArrayRef<NamedAttribute> attributes) {
2244 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2245 attributes);
2246}
2247
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the destination-style parts plus the `dimensions` attribute.
  // NOTE(review): the opening `if (failed(parseDstStyleOp(` line appears to
  // be missing from this view — verify against upstream.
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  // The region is not spelled in the assembly; rebuild the canonical
  // identity body from the parsed operands.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
2261
2262void BroadcastOp::getAsmResultNames(
2263 function_ref<void(Value, StringRef)> setNameFn) {
2264 if (!getResults().empty())
2265 setNameFn(getResults().front(), "broadcasted");
2266}
2267
2268void BroadcastOp::print(OpAsmPrinter &p) {
2269 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2270 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2271 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2272}
2273
2274LogicalResult BroadcastOp::verify() {
2275 ArrayRef<int64_t> dimensionsRef = getDimensions();
2276
2277 auto inputType = getInput().getType();
2278 auto initType = getInit().getType();
2279
2280 int64_t inputRank = inputType.getRank();
2281 int64_t initRank = initType.getRank();
2282
2283 auto inputShape = inputType.getShape();
2284 auto initShape = initType.getShape();
2285
2286 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2287 return emitOpError() << "input rank plus added dimensions does not "
2288 "match init rank. input rank: "
2289 << inputRank
2290 << ", dimensions size: " << dimensionsRef.size()
2291 << ", init rank: " << initRank;
2292
2293 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2294 if (dim < 0 || dim >= initRank)
2295 return emitOpError() << "dimension " << idx
2296 << " is out of range. expected range: [0, "
2297 << initRank - 1 << "], got: " << dim;
2298 }
2299
2300 // Mapping from input dims to init dims.
2301 SmallVector<int64_t> dimMap;
2302 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2303 if (!llvm::is_contained(dimensionsRef, dim))
2304 dimMap.push_back(dim);
2305 }
2306
2307 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2308 // This dimensions is mapped from the input. Init and input dims should
2309 // match.
2310 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2311 return emitOpError() << "input dim " << inputDimIdx
2312 << " should match init dim " << initDimIdx
2313 << ". input: " << inputShape[inputDimIdx]
2314 << ", init: " << initShape[initDimIdx];
2315 }
2316
2317 return success();
2318}
2319
2320SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2321 int64_t rank = getInit().getType().getRank();
2322 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2323}
2324
2325ArrayAttr BroadcastOp::getIndexingMaps() {
2326 Builder builder(getContext());
2327 int64_t rank = getInit().getType().getRank();
2328 return builder.getAffineMapArrayAttr(
2329 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2330 builder.getMultiDimIdentityMap(rank)});
2331}
2332
2333void BroadcastOp::getEffects(
2334 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2335 &effects) {
2336 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2337}
2338
2339Speculation::Speculatability BroadcastOp::getSpeculatability() {
2340 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2341}
2342
2343/// Fold back-to-back broadcasts together.
2344struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2345 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2346
2347 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2348 PatternRewriter &rewriter) const override {
2349 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2350 if (!defBroadcastOp)
2351 return failure();
2352 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2353 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2354 SmallVector<int64_t> foldedDims(dimensions);
2355 Value init = broadcastOp.getInit();
2356 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2357 // Mapping from input dims to init dims.
2358 SmallVector<int64_t> dimMap;
2359 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2360 if (!llvm::is_contained(dimensions, dim))
2361 dimMap.push_back(dim);
2362 }
2363 for (auto dim : defDimensions)
2364 foldedDims.push_back(dimMap[dim]);
2365
2366 llvm::sort(foldedDims);
2367 rewriter.replaceOpWithNewOp<BroadcastOp>(
2368 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2369 return success();
2370 }
2371};
2372
2373void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2374 MLIRContext *context) {
2375 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2376}
2377
2378//===----------------------------------------------------------------------===//
2379// YieldOp
2380//===----------------------------------------------------------------------===//
2381
2382void linalg::YieldOp::print(OpAsmPrinter &p) {
2383 if (getNumOperands() > 0)
2384 p << ' ' << getOperands();
2385 p.printOptionalAttrDict((*this)->getAttrs());
2386 if (getNumOperands() > 0)
2387 p << " : " << getOperandTypes();
2388}
2389
2390ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2391 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2392 SmallVector<Type, 2> types;
2393 SMLoc loc = parser.getCurrentLocation();
2394 return failure(parser.parseOperandList(opInfo) ||
2395 parser.parseOptionalAttrDict(result.attributes) ||
2396 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2397 parser.resolveOperands(opInfo, types, loc, result.operands));
2398}
2399
2400// Check the operand number and types must match the element types of the
2401// LinalgOp interface's shaped operands.
2402static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2403 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2404 return op.emitOpError("expected number of yield values (")
2405 << op.getNumOperands()
2406 << ") to match the number of inits / outs operands of the enclosing "
2407 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2408
2409 for (OpOperand &opOperand : op->getOpOperands()) {
2410 OpOperand *outputOperand =
2411 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2412 Type elementType = outputOperand->get().getType();
2413 if (isa<MemRefType, RankedTensorType>(elementType))
2414 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2415 if (opOperand.get().getType() != elementType)
2416 return op.emitOpError("type of yield operand ")
2417 << (opOperand.getOperandNumber() + 1) << " ("
2418 << opOperand.get().getType() << ") doesn't match "
2419 << "the element type of the enclosing linalg.generic op ("
2420 << elementType << ")";
2421 }
2422 return success();
2423}
2424
2425LogicalResult linalg::YieldOp::verify() {
2426 auto *parentOp = (*this)->getParentOp();
2427 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2428 return emitOpError("expected single non-empty parent region");
2429
2430 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2431 return verifyYield(*this, linalgOp);
2432
2433 return emitOpError("expected parent op with LinalgOp interface");
2434}
2435
2436//===----------------------------------------------------------------------===//
2437// IndexOp
2438//===----------------------------------------------------------------------===//
2439
2440LogicalResult IndexOp::verify() {
2441 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2442 if (!linalgOp)
2443 return emitOpError("expected parent op with LinalgOp interface");
2444 if (linalgOp.getNumLoops() <= getDim())
2445 return emitOpError("expected dim (")
2446 << getDim() << ") to be lower than the number of loops ("
2447 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2448 return success();
2449}
2450
2451OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2452 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2453 // Bail out if `linalg.index` does not have a proper parent yet at this
2454 // point, e.g., when calling `createOrFold` during IR construction in
2455 // `genericOp::build`.
2456 if (!linalgOp)
2457 return OpFoldResult{};
2458
2459 // Index of unit dims is always 0.
2460 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2461 uint64_t dim = getDim();
2462 assert(dim < loopBounds.size() && "Dim is out of bounds");
2463 if (loopBounds[dim] == 1)
2464 return IntegerAttr::get(IndexType::get(getContext()), 0);
2465
2466 return OpFoldResult{};
2467}
2468
2469/////// Operations corresponding to library calls defined with Tablegen ////////
2470
2471#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2472
2473#define GET_OP_CLASSES
2474#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2475
2476#define GET_OP_CLASSES
2477#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2478#define GET_OP_CLASSES
2479#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2480
2481AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2482 unsigned rank,
2483 MLIRContext *context) {
2484 if (maybeMap)
2485 return *maybeMap;
2486 if (rank == 0)
2487 return AffineMap::get(context);
2488 return AffineMap::getMultiDimIdentityMap(rank, context);
2489}
2490
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  // Appends `num` consecutive affine dim expressions starting at `startIdx`,
  // advancing `startIdx` past them so successive calls yield disjoint dims.
  // NOTE(review): the function's return-type line and the declaration of
  // `res` appear to be missing from this view — verify against upstream.
    res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}
2500
  // Concatenates two AffineExpr sequences into a single vector.
  // NOTE(review): the function signature lines are not visible in this
  // view — verify against upstream.
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}
2508
/// Appends a mangled encoding of `t` to `ss` for library-call name
/// generation. Supports memrefs, vectors, and (per the trailing branch)
/// plain scalar types; returns failure for anything else.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  // Memrefs mangle as "view" + per-dim "<size>x" ("sx" for dynamic) +
  // element type, optionally followed by "as<space>" for an integer memory
  // space.
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        // Only integer memory spaces can be mangled.
        return failure();
    }
    return success();
  }
  // Vectors mangle as "vector" + sizes joined by "x" + element type.
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  // NOTE(review): a guard line (presumably `if (t.isIntOrIndexOrFloat()) {`)
  // appears to be missing from this view — verify against upstream.
    ss << t;
    return success();
  }
  return failure();
}
2541
  // Builds a mangled library-call name from the op name, any unary/binary fn
  // attribute, and the mangled operand types.
  // NOTE(review): the function signature line (presumably
  // `std::string mlir::linalg::generateLibraryCallName(Operation *op) {`)
  // is not visible in this view — verify against upstream.
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  // Fold a unary/binary fn attribute (if present) into the name.
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  name.reserve(128);
  llvm::replace(name, '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  // Append the mangled operand types; bail out with an empty name if any
  // type cannot be mangled.
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  // Drop the trailing '_' separator.
  name.pop_back();
  return name;
}
2565
2566//===----------------------------------------------------------------------===//
2567// Canonicalizers and Folders.
2568//===----------------------------------------------------------------------===//
2569
2570namespace {
/// Erases linalg ops that provably perform zero iterations because one of
/// their memref operands has a static dimension of size 0.
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  // NOTE(review): the `using OpInterfaceRewritePattern<LinalgOp>::
  // OpInterfaceRewritePattern;` line appears to be missing from this view —
  // verify against upstream.

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};
2591
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // The cast must be a "shape-refining" cast that can be absorbed by its
    // producer.
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // Cast can be in conditionally reachable region, if which case folding will
    // generate invalid code. Only conservatively fold ops in same block for
    // now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    // Which result of the linalg op feeds the cast.
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
    // Rebuild the full operand list with the refined init swapped in.
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    // Clone the op with the refined result type.
    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = tensor::CastOp::create(
        rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));

    // Other results keep their original uses; the cast's users take the new,
    // more static result directly.
    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
2650
2651/// For each of the operand in `operands` this function maps the static sizes of
2652/// dimensions to their affine dim expressions.
2653static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2654 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2655 for (OpOperand &opOperand : operands) {
2656 if (linalgOp.isScalar(&opOperand))
2657 continue;
2658 Value src = opOperand.get();
2659 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2660 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2661
2662 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2663 // `tensor.cast` operation and source of the cast operation has a static
2664 // shape, then assign it to the `sourceShape`.
2665 auto *parentOp = src.getDefiningOp();
2666 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2667 if (parentOp) {
2668 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2669 Value castSource = castOp.getSource();
2670 auto castSourceType =
2671 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2672 if (castSourceType && castSourceType.hasStaticShape())
2673 sourceShape = castSourceType.getShape();
2674 }
2675 }
2676
2677 // If the source shape's dimension has a static shape, map the affine dim
2678 // expression to the known static size.
2679 for (unsigned i = 0; i < sourceShape.size(); i++) {
2680 if (sourceType.isDynamicDim(i))
2681 continue;
2682 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2683 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2684 }
2685 }
2686}
2687
/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types is stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and same operand is added in the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Push the operand unchanged up-front; it is overwritten in place below
  // only when a refining cast is actually created.
  newOperands.push_back(src);
  // Scalar operands have no shape to refine.
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  // A fully-static init operand cannot be improved; record its type (init
  // types determine result types) and bail early.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    // Keep the dimension as-is when it is already static or no static size
    // is known for its dim expression.
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2740
2741/// Static shapes for the operands can be inferred if any one of the operands
2742/// have a static shape. This can be done by referring to the affine dim
2743/// expressions for the operand.
2744struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2745 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2746
2747 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2748 PatternRewriter &rewriter) const override {
2749 if (!linalgOp.hasPureTensorSemantics())
2750 return failure();
2751
2752 // Maps must be projected permutations.
2753 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2754 return !map.isProjectedPermutation();
2755 }))
2756 return failure();
2757
2758 // Maps affine dim expressions to the static size of that dimension.
2759 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2760 Location loc = linalgOp.getLoc();
2761
2762 // For each of the affine dim expression, check if the size is known. If
2763 // known add that in the map.
2764 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2765
2766 SmallVector<Value> newOperands;
2767 SmallVector<Type> resultTypes;
2768
2769 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2770 // change in their types.
2771 bool changeNeeded = false;
2772 newOperands.reserve(linalgOp->getNumOperands());
2773 resultTypes.reserve(linalgOp.getNumDpsInits());
2774
2775 // Iterate over all the operands and update the static sizes.
2776 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2777 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2778 affineExprToSize, linalgOp, newOperands,
2779 resultTypes, changeNeeded);
2780 }
2781
2782 // If the generic op has all the required static information, no
2783 // canonicalization needed.
2784 if (!changeNeeded)
2785 return failure();
2786
2787 // Clone op.
2788 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2789 SmallVector<Value> replacements;
2790 replacements.reserve(newOp->getNumResults());
2791 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2792 Value newResult = std::get<1>(it);
2793 Value oldResult = std::get<0>(it);
2794 Type newType = newResult.getType();
2795 Type oldType = oldResult.getType();
2796 replacements.push_back(
2797 (newType != oldType)
2798 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2799 : newResult);
2800 }
2801 rewriter.replaceOp(linalgOp, replacements);
2802 return success();
2803 }
2804};
2805
2806} // namespace
2807
2808// All named ops canonicalizers and folders are auto-generated in the
2809// .cpp.inc.
2810
2811//===----------------------------------------------------------------------===//
2812// SoftmaxOp
2813//===----------------------------------------------------------------------===//
2814
2815LogicalResult SoftmaxOp::verify() {
2816 ShapedType inputType = getInputOperandType();
2817 ShapedType outputType = getOutputOperandType();
2818
2819 ArrayRef<int64_t> inputShape = inputType.getShape();
2820 ArrayRef<int64_t> outputShape = outputType.getShape();
2821 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2822 return emitOpError("incompatible output shape");
2823
2824 int64_t inputRank = getInputOperandRank();
2825 int64_t dimension = getDimension();
2826 if ((dimension < 0) || (dimension >= inputRank))
2827 return emitOpError("incorrect dimension specified");
2828
2829 return success();
2830}
2831
2832SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2833 int64_t operandRank = getInputOperandRank();
2834 SmallVector<Range> loopBounds(operandRank);
2835 Location loc = getLoc();
2836 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2837 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2838 Value source = getInput();
2839 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2840 loopBounds[dim].offset = zero;
2841 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2842 loopBounds[dim].stride = one;
2843 }
2844 return loopBounds;
2845}
2846
2847SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2848 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2849 utils::IteratorType::parallel);
2850 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2851 return iteratorTypes;
2852}
2853
/// Tile softmax by extracting matching slices of the input and output and
/// cloning the op onto those slices. `offsets`/`sizes` describe the tile in
/// the iteration space, which coincides with the operand space here.
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  // Slices are always taken with unit stride.
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  Operation *inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (!inputSlice) {
    return emitOpError("failed to compute input slice");
  }
  tiledOperands.emplace_back(inputSlice->getResult(0));
  Operation *outputSlice =
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  if (!outputSlice) {
    return emitOpError("failed to compute output slice");
  }
  tiledOperands.emplace_back(outputSlice->getResult(0));

  // With tensor semantics the tiled op returns a result typed like the
  // output slice; with memref semantics there is no result.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
2886
2887LogicalResult SoftmaxOp::getResultTilePosition(
2888 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2889 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2890 SmallVector<OpFoldResult> &resultSizes) {
2891 if (resultNumber == 0) {
2892 resultOffsets.assign(offsets.begin(), offsets.end());
2893 resultSizes.assign(sizes.begin(), sizes.end());
2894 return success();
2895 }
2896 return failure();
2897}
2898
// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Fold `memref.cast` producers into this op when the cast only refines
  // the operand type (e.g. makes dimensions static).
  return memref::foldMemRefCast(*this);
}
2903
2904LogicalResult
2905SoftmaxOp::reifyResultShapes(OpBuilder &b,
2906 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2907 SmallVector<OpFoldResult> shapes;
2908 Location loc = getOperation()->getLoc();
2909 IRRewriter rewriter(b);
2910 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2911 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2912 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2913 if (!outputShapedType.isDynamicDim(dim)) {
2914 // Static dim: Return IntegerAttr.
2915 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2916 } else {
2917 // Dynamic dim: Return Value.
2918 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2919 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2920 }
2921 }
2922 reifiedReturnShapes.emplace_back(std::move(shapes));
2923 return success();
2924}
2925
2926void SoftmaxOp::getEffects(
2927 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2928 &effects) {
2929 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2930 if (!llvm::isa<MemRefType>(operand.getType()))
2931 continue;
2932 effects.emplace_back(MemoryEffects::Read::get(),
2933 &getOperation()->getOpOperand(index), /*stage=*/0,
2934 /*effectOnFullRegion=*/true,
2936 }
2937
2938 for (OpOperand &operand : getDpsInitsMutable()) {
2939 if (!llvm::isa<MemRefType>(operand.get().getType()))
2940 continue;
2941 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2942 /*effectOnFullRegion=*/true,
2944 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2945 /*effectOnFullRegion=*/true,
2947 }
2948}
2949
2950// Helper functions for softmax decomposition.
2951// @{
2952
2953// Helper function to produce the iterator types (reduction or parallel) and
2954// affine maps for the iterators used in the decomposition of softmax.
2955// This method creates:
2956// If allParallel == true:
2957// - iterator type: {parallel, ..., parallel}
2958// - affine maps:
2959// -- identity with inputRank dimensions.
2960// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2961// where N == inputRank.
2962//
2963// If allParallel == false:
2964// - iterator type at dim(i) == parallel for i != \p dim and
2965// dim(dim) == reduction.
2966// - affine map:
2967// -- identity with inputRank dimensions.
2968// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2969// where N == inputRank.
2970static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2972 int64_t dim, bool allParallel = false) {
2973 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2974 utils::IteratorType::parallel);
2975 if (!allParallel)
2976 iteratorTypes[dim] = utils::IteratorType::reduction;
2977 MLIRContext *ctxt = builder.getContext();
2978 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2979 SmallVector<AffineExpr, 2> affineExprs;
2980 for (int i = 0; i < inputRank; i++) {
2981 if (i != dim)
2982 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2983 }
2984 auto reductionMap =
2985 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2986 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2987 return std::make_tuple(iteratorTypes, indexingMaps);
2988}
2989
// Helper function to produce a linalg.generic that computes a reduction on
// dimension \p dim with the operation type \p T.
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  // The region combines the running accumulator (args[1], initialized from
  // `output`) with the current input element (args[0]) using binary op T.
  auto genericOp = linalg::GenericOp::create(
      builder, loc, output.getType(), input, output, indexingMaps,
      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = T::create(b, loc, args[0], args[1]);
        linalg::YieldOp::create(b, loc, result);
      });
  return genericOp.getResult(0);
}
3012
3013/// Produce a linalg generic that computes the second step of the softmax
3014/// decomposition: res = exp(input - max), where \p max is the max of \p input
3015/// on dimension \p dim.
3016static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3017 Value max, Value output, int64_t dim) {
3018 auto inputType = cast<ShapedType>(input.getType());
3019 ArrayRef<int64_t> inputShape = inputType.getShape();
3020 int64_t inputRank = inputShape.size();
3021 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3022 builder, inputRank, dim, /*allParallel=*/true);
3023 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3024 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3025 // Add the affine map for the output argument.
3026 indexingMaps.push_back(indexingMaps[0]);
3027 auto genericOp = linalg::GenericOp::create(
3028 builder, loc, input.getType(), ValueRange{input, max}, output,
3029 indexingMaps, iteratorTypes,
3030 [&](OpBuilder &b, Location loc, ValueRange args) {
3031 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3032 Value result = math::ExpOp::create(b, loc, diff);
3033 linalg::YieldOp::create(b, loc, result);
3034 });
3035 return genericOp.getResult(0);
3036}
3037
3038/// Produce a linalg generic that computes the final step of the softmax
3039/// decomposition.
3040/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3041/// yield n / d
3042/// }
3043static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3044 Value denominator, Value output, int64_t dim) {
3045 auto inputType = cast<ShapedType>(numerator.getType());
3046 ArrayRef<int64_t> inputShape = inputType.getShape();
3047 int64_t inputRank = inputShape.size();
3048 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3049 builder, inputRank, dim, /*allParallel=*/true);
3050 assert(indexingMaps.size() == 2 &&
3051 "We should have one map for each input (2)");
3052 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3053 // Add the affine map for the output tensor.
3054 indexingMaps.push_back(indexingMaps[0]);
3055 auto genericOp = linalg::GenericOp::create(
3056 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3057 output, indexingMaps, iteratorTypes,
3058 [&](OpBuilder &b, Location loc, ValueRange args) {
3059 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3060 linalg::YieldOp::create(b, loc, result);
3061 });
3062 return genericOp.getResult(0);
3063}
3064// @} End helper functions for softmax decomposition.
3065
/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in a N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    a N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    a N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  // The reduction results (max, sum) use the input shape with the reduced
  // dimension dropped.
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
  // Seed the max reduction with the identity of maxnumf restricted to
  // finite values.
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
3122
3123//===----------------------------------------------------------------------===//
3124// WinogradFilterTransformOp
3125//===----------------------------------------------------------------------===//
3126
3127LogicalResult WinogradFilterTransformOp::verify() {
3128 auto filterType = cast<ShapedType>(getFilter().getType());
3129 ArrayRef<int64_t> filterShape = filterType.getShape();
3130 int64_t filterH = filterShape[getFilterHDim()];
3131 int64_t filterW = filterShape[getFilterWDim()];
3132 WinogradConv2DFmr fmr = getFmr();
3133 int64_t m, r;
3134 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3135
3136 if (filterH != r && filterH != 1)
3137 return emitOpError("expect filter height either equals to r or 1");
3138 if (filterW != r && filterW != 1)
3139 return emitOpError("expect filter width either equals to r or 1");
3140 if (filterH == 1 && filterW == 1)
3141 return emitOpError("expect either filter height or width equals to r");
3142
3143 SmallVector<int64_t> expectedOutputShape;
3144 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3145 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3146 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3147 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3148
3149 auto outputType = cast<ShapedType>(getOutput().getType());
3150 ArrayRef<int64_t> outputShape = outputType.getShape();
3151 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3152 return emitOpError("the output shape is not expected");
3153 }
3154 return success();
3155}
3156
3157SmallVector<Range>
3158WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3159 Location loc = getLoc();
3160 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3161 IntegerAttr oneAttr = builder.getIndexAttr(1);
3162 Value filter = getFilter();
3163 int64_t filterRank = getFilterOperandRank();
3164 SmallVector<Range> loopBounds(filterRank);
3165 for (unsigned dim = 0; dim < filterRank; ++dim) {
3166 loopBounds[dim].offset = zeroAttr;
3167 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3168 loopBounds[dim].stride = oneAttr;
3169 }
3170 return loopBounds;
3171}
3172
3173SmallVector<utils::IteratorType>
3174WinogradFilterTransformOp::getLoopIteratorTypes() {
3175 int64_t filterRank = getFilterOperandRank();
3176 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3177 utils::IteratorType::parallel);
3178 return iteratorTypes;
3179}
3180
3181LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3182 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3183 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3184 SmallVector<OpFoldResult> &resultSizes) {
3185 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3186 ShapedType filterType = getFilterOperandType();
3187 ArrayRef<int64_t> filterShape = filterType.getShape();
3188 int64_t filterH = filterShape[getFilterHDim()];
3189 int64_t filterW = filterShape[getFilterWDim()];
3190 WinogradConv2DFmr fmr = getFmr();
3191 int64_t m, r;
3192 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3193 int64_t alpha = m + r - 1;
3194 int64_t alphaH = filterH != 1 ? alpha : 1;
3195 int64_t alphaW = filterW != 1 ? alpha : 1;
3196 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3197 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3198
3199 resultOffsets.append(
3200 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3201 resultSizes.append(
3202 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3203
3204 return success();
3205}
3206
/// Implement tiling for winograd_filter_transform
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // The filter tile always covers the full KH x KW extent; only the F and C
  // dimensions are tiled.
  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = tensor::ExtractSliceOp::create(
      builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  // NOTE(review): resultNumber is passed as 1 although the op has a single
  // result; getResultTilePosition ignores it above — confirm 0 was not
  // intended.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // The tiled op's result type matches the output slice type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
3259
3260//===----------------------------------------------------------------------===//
3261// WinogradInputTransformOp
3262//===----------------------------------------------------------------------===//
3263
/// Verify the expected (alphaH, alphaW, tileH, tileW, N, C) output shape of
/// winograd_input_transform against the (N, H, W, C) input and the F(m, r)
/// parameters.
LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // alpha = m + r - 1 is the side length of one transformed tile.
  int64_t tileSize = m + r - 1;

  // A size-1 alpha dimension in the output marks a direction that is not
  // transformed.
  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    // Dynamic input height: the tile count cannot be computed statically.
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    // Dynamic input width: the tile count cannot be computed statically.
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  // Batch and channel dimensions pass through unchanged.
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3304
3305SmallVector<Range>
3306WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3307 Location loc = getLoc();
3308 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3309 IntegerAttr oneAttr = builder.getIndexAttr(1);
3310 Value output = getOutput();
3311 int64_t outputRank = getOutputOperandRank();
3312 SmallVector<Range> loopBounds(outputRank);
3313 for (unsigned dim = 0; dim < outputRank; ++dim) {
3314 loopBounds[dim].offset = zeroAttr;
3315 // alphaH, alphaW, tileH, tileW, N, C
3316 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3317 loopBounds[dim].stride = oneAttr;
3318 }
3319 return loopBounds;
3320}
3321
3322SmallVector<utils::IteratorType>
3323WinogradInputTransformOp::getLoopIteratorTypes() {
3324 int64_t outputRank = getOutputOperandRank();
3325 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3326 utils::IteratorType::parallel);
3327 return iteratorTypes;
3328}
3329
3330LogicalResult WinogradInputTransformOp::getResultTilePosition(
3331 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3332 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3333 SmallVector<OpFoldResult> &resultSizes) {
3334 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3335 ShapedType outputType = getOutputOperandType();
3336 ArrayRef<int64_t> outputShape = outputType.getShape();
3337 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3338 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3339
3340 WinogradConv2DFmr fmr = getFmr();
3341 int64_t m, r;
3342 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3343 int64_t alpha = m + r - 1;
3344 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3345 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3346
3347 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3348 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3349
3350 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3351 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3352 offsets[getOutputCDim()]});
3353 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3354 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3355 sizes[getOutputCDim()]});
3356
3357 return success();
3358}
3359
3360/// Implement tiling for winograd_input_transform
3361/// The input of winograd_input_transform is (N, H, W, C).
3362/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3363/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3364/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3365/// the values for the sizes of tileH, tileW, N, C for one tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  // Map a tile index to input-image coordinates: tile i starts at input
  // position i * m and reads m + r - 1 elements, so consecutive tiles
  // overlap by r - 1. When the static alpha extent is 1, the dimension is
  // not transformed and the offset passes through unchanged.
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  // A run of `size` tiles covers size * m + (r - 1) input elements.
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // Extract the (N, H, W, C) input slice feeding this tile.
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  // Extract the matching output slice.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the sliced operands; the result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
3439
3440//===----------------------------------------------------------------------===//
3441// WinogradOutputTransformOp
3442//===----------------------------------------------------------------------===//
3443
/// Verifies that the (N, H, W, F) output shape is consistent with the
/// (alphaH, alphaW, tileH, tileW, N, F) value operand: static alpha dims
/// must equal m + r - 1 (or 1 when untransformed) and each transformed
/// output dim must be m * tile count. Dynamic dims are accepted as-is.
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // An alpha extent of 1 means the transform is skipped on that axis.
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  // N and F carry over from the value operand unchanged.
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3484
3485SmallVector<Range>
3486WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3487 Location loc = getLoc();
3488 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3489 IntegerAttr oneAttr = builder.getIndexAttr(1);
3490 Value value = getValue();
3491 int64_t valueRank = getValueOperandRank();
3492 SmallVector<Range> loopBounds(valueRank);
3493 for (unsigned dim = 0; dim < valueRank; ++dim) {
3494 loopBounds[dim].offset = zeroAttr;
3495 // alphaH, alphaW, tileH, tileW, N, F
3496 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3497 loopBounds[dim].stride = oneAttr;
3498 }
3499 return loopBounds;
3500}
3501
3502SmallVector<utils::IteratorType>
3503WinogradOutputTransformOp::getLoopIteratorTypes() {
3504 int64_t valueRank = getValueOperandRank();
3505 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3506 utils::IteratorType::parallel);
3507 return iteratorTypes;
3508}
3509
3510LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3511 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3512 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3513 SmallVector<OpFoldResult> &resultSizes) {
3514 WinogradConv2DFmr fmr = getFmr();
3515 int64_t m, r;
3516 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3517
3518 Location loc = getLoc();
3519 MLIRContext *context = builder.getContext();
3520 auto identityAffineMap =
3521 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3522 auto affineMap =
3523 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3524
3525 ShapedType valueType = getValueOperandType();
3526 ArrayRef<int64_t> valueShape = valueType.getShape();
3527 int64_t valueH = valueShape[0];
3528 int64_t valueW = valueShape[1];
3529 Value mappedOffsetH = affine::makeComposedAffineApply(
3530 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3531 offsets[getValueTileHDim()]);
3532 Value mappedOffsetW = affine::makeComposedAffineApply(
3533 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3534 offsets[getValueTileWDim()]);
3535 Value mappedSizeH = affine::makeComposedAffineApply(
3536 builder, loc, affineMap, sizes[getValueTileHDim()]);
3537 Value mappedSizeW = affine::makeComposedAffineApply(
3538 builder, loc, affineMap, sizes[getValueTileWDim()]);
3539
3540 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3541 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3542 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3543 OpFoldResult sizeH =
3544 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3545 OpFoldResult sizeW =
3546 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3547
3548 resultOffsets.append(
3549 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3550 resultSizes.append(
3551 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3552 return success();
3553}
3554
3555/// Implement tiling for winograd_output_transform
3556/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3557/// F). The output of winograd_output_transform is (N, H, W, F) Users can
3558/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3559/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3560/// for the sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // The value slice spans the full alphaH x alphaW leading dims and the
  // requested tileH/tileW/N/F tile.
  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = tensor::ExtractSliceOp::create(
      builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  // Extract the matching (N, H, W, F) output slice.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the sliced operands; the result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
3610
3611//===----------------------------------------------------------------------===//
3612// LinalgDialect
3613// TODO: Merge with the LinalgDialect block at the bottom
3614//===----------------------------------------------------------------------===//
3615
3616// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3617static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3618 auto explicitRange = subMap.getResults();
3619 auto defaultRange = fullMap.getResults();
3620 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3621 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3622 llvm::set_union(explicitSet, defaultSet);
3623 return explicitSet == defaultSet;
3624}
3625
3626/// Check if the user defined map is valid broadcast map. Here broadcast
3627/// indexing maps are defined in context of corresponding default indexing maps
3628/// for the given Op. This way the check becomes very simple i.e just check the
3629/// number of result dims.
3630/// Returns true if the explictMap is broadcasted with respect to the
3631/// defaultMap.
3632static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3633 return explictMap.getNumResults() < defaultMap.getNumResults();
3634}
3635
3636/// Verifies the broadcast and transpose semantic sepecified by the explicit
3637/// indexing map for the MatmulOp \p op for each operand specified by \p
3638/// opIndex.
3639static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3640 unsigned opIndex) {
3641 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3642 SmallVector<AffineMap, 3> defaultIndexingMaps =
3643 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3644
3645 auto opIndexingMap = opIndexingMaps[opIndex];
3646 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3647 // Check general validity of indexing map results.
3648 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3649 return matmulOp->emitOpError()
3650 << "Unexpected dim expression in map result.";
3651
3652 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3653 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3654 return matmulOp->emitOpError()
3655 << "Invalid broadcast requested, should be (d2).";
3656 }
3657 return success();
3658 }
3659 return success();
3660}
3661
3662// Check general validity of input indexing map of
3663// BatchMatmulOp/BatchReduceMatmulOp.
3664template <typename OpTy>
3665static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3666 AffineMap opIndexingMap,
3667 AffineMap defaultIndexingMap, bool isLHS) {
3668 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3669 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3670 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3671 // Check the result dims are valid.
3672 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3673 return batchVariantMatmulOp->emitOpError()
3674 << "Unexpected result dim expression (outside the set of default "
3675 "result dims).";
3676
3677 // Check for valid number of result dims of input maps.
3678 if (opIndexingMap.getNumResults() > 3)
3679 return batchVariantMatmulOp->emitOpError()
3680 << "no. of result dim expressions exceeds 3.";
3681
3682 auto hasValidBatchDim = [](AffineMap map) {
3683 AffineExpr batchDim = map.getResult(0);
3684 return batchDim.isFunctionOfDim(0);
3685 };
3686
3687 // Check if the requested broadcast is valid.
3688 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3689 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3690 return batchVariantMatmulOp->emitOpError()
3691 << "Invalid broadcast requested.";
3692 } else if (!hasValidBatchDim(opIndexingMap)) {
3693 return batchVariantMatmulOp->emitOpError()
3694 << "Invalid batch dimension expression.";
3695 }
3696 return success();
3697}
3698
3699/// This function checks if the given AffineMap for the output of a
3700/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3701/// dimensions and if the output map result dimensions are valid.
3702template <typename OpTy>
3703static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3704 AffineMap opIndexingMap) {
3705 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3706 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3707 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3708 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3709 opIndexingMap.getNumResults() != 3) {
3710
3711 return batchVariantMatmulOp->emitOpError()
3712 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3713 << ").";
3714 }
3715 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3716 opIndexingMap.getNumResults() != 2) {
3717 return batchVariantMatmulOp->emitOpError()
3718 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3719 << ").";
3720 }
3721
3722 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3723 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3724 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3725 outputMap.getResult(1).isFunctionOfDim(1) &&
3726 outputMap.getResult(2).isFunctionOfDim(2)
3727 : outputMap.getResult(0).isFunctionOfDim(1) &&
3728 outputMap.getResult(1).isFunctionOfDim(2);
3729 };
3730
3731 if (!areValidOutputResultDim(opIndexingMap)) {
3732 return batchVariantMatmulOp->emitOpError()
3733 << "Invalid output map result dimension.";
3734 }
3735
3736 return success();
3737}
3738
3739/// Verifies the broadcast and transpose semantic specified by the explicit
3740/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3741/// specified by opIndex.
3742template <typename OpTy>
3743static LogicalResult
3745 unsigned opIndex) {
3746 SmallVector<AffineMap, 3> opIndexingMaps =
3747 batchVariantMatmulOp.getIndexingMapsArray();
3748 SmallVector<AffineMap, 3> defaultIndexingMaps =
3749 batchVariantMatmulOp.getDefaultIndexingMaps(
3750 batchVariantMatmulOp->getContext());
3751
3752 if (opIndexingMaps.size() != 3)
3753 return batchVariantMatmulOp->emitOpError()
3754 << "Indexing_map attribute must have 3 affine maps.";
3755
3756 auto opIndexingMap = opIndexingMaps[opIndex];
3757 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3758
3759 if (opIndex == 2 &&
3760 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3761 return failure();
3762
3763 if (opIndex != 2 &&
3764 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3765 defaultIndexingMap, opIndex == 0)))
3766 return failure();
3767
3768 return success();
3769}
3770
3771namespace mlir {
3772namespace linalg {
3773
3774std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3775 if (m == 2 && r == 3)
3776 return WinogradConv2DFmr::F_2_3;
3777 if (m == 4 && r == 3)
3778 return WinogradConv2DFmr::F_4_3;
3779 if (m == 2 && r == 5)
3780 return WinogradConv2DFmr::F_2_5;
3781 return std::nullopt;
3782}
3783
3784std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3785 switch (fmr) {
3786 case WinogradConv2DFmr::F_2_3:
3787 return {2, 3};
3788 case WinogradConv2DFmr::F_4_3:
3789 return {4, 3};
3790 case WinogradConv2DFmr::F_2_5:
3791 return {2, 5};
3792 }
3793 llvm_unreachable("Unkown WinogradConv2DFmr");
3794}
3795
3796//===----------------------------------------------------------------------===//
3797// MatMulOp
3798//===----------------------------------------------------------------------===//
3799
3800static FailureOr<SmallVector<SmallVector<int64_t>>>
3803 for (auto map : maps) {
3804 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3805 if (!attr)
3806 return failure();
3808 for (auto result : attr.getAffineMap().getResults()) {
3809 auto dim = dyn_cast<AffineDimExpr>(result);
3810 if (!dim)
3811 return failure();
3812 pos.push_back(dim.getPosition());
3813 }
3814 positions.push_back(pos);
3815 }
3816 return positions;
3817}
3818
3819/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3820SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3821 AffineExpr d0, d1, d2;
3822 SmallVector<AffineMap> indexingMaps;
3823 bindDims(context, d0, d1, d2);
3824 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3825 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3826 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3827 return indexingMaps;
3828}
3829
3830bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3831 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3832 if (!maps)
3833 return false;
3834 if (maps.size() != 3)
3835 return false;
3836 auto positions = getAffineResultPositions(maps);
3837 if (failed(positions))
3838 return false;
3839 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3840 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3841 (*positions)[2] == SmallVector<int64_t>{0, 1};
3842}
3843
3844SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3845 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3846 utils::IteratorType::parallel,
3847 utils::IteratorType::reduction};
3848}
3849
/// Number of region block arguments: two input elements plus the output.
unsigned MatmulOp::getNumRegionArgs() { return 3; }
3851
/// Returns the library call name generated for this operation.
std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
3855
/// The indexing maps can be overridden via the op's `indexing_maps` attribute.
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3857
3858/// Check if the op has broadcast and/or transpose semantic. Returns true if
3859/// the user defined indexing maps are not equal to default map.
3860bool MatmulOp::hasUserDefinedMaps() {
3861 SmallVector<AffineMap, 3> defaultMaps =
3862 getDefaultIndexingMaps(this->getContext());
3863 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3864 return defaultMaps != explicitMaps;
3865}
3866
3867/// Implements the block region builder for the MatmulOp. This is called by
3868/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  // The region takes (lhs element, rhs element, accumulator element).
  // Report a diagnostic when a callback is provided; otherwise assert.
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  // Cast function defaults to signed; can be overridden by a "cast"
  // attribute of type TypeFnAttr.
  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  // Build acc = acc + cast(lhs) * cast(rhs); both operands are cast to the
  // accumulator element type first. Bail out if any step fails to build.
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value1 || !value2 || !value3)
    return;
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
3905
3906/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3907/// broadcast map must include K dimension.
3908/// TODO: Strict inclusion of K dimension in the broadcast map is not
3909/// necessary for both input matrices simultaneously. We can relax this
3910/// condition to have K dimension for one input matrix map and infer the K
3911/// dimension for other input matrix map from the one already having K
3912/// dimension.
3913bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3914 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3915 AffineExpr expr = bcastMap.getResult(0);
3916 // Invalid map if the common dimension of matmul not found.
3917 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3918}
3919
3920static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3921 if (parser.parseOptionalKeyword("indexing_maps"))
3922 return ArrayAttr{
3923 nullptr}; // Success in case indexing_maps was not provided.
3924
3925 ArrayAttr arrayAttr;
3926 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3927 return failure();
3928
3929 if (llvm::any_of(arrayAttr,
3930 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3931 return parser.emitError(parser.getCurrentLocation())
3932 << "element of indexing_maps array is not an affine_map";
3933
3934 return arrayAttr;
3935}
3936
3937ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3938 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3939 if (failed(indexingMapsAttr))
3940 return failure();
3941
3942 if (*indexingMapsAttr == nullptr) {
3943 auto indexingMapAttrs = llvm::map_to_vector(
3944 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3945 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3946 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3947 }
3948
3949 result.addAttribute("indexing_maps", *indexingMapsAttr);
3950 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3951 MatmulOp::getRegionBuilder());
3952}
3953
3954void MatmulOp::print(OpAsmPrinter &p) {
3955 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3956 MatmulOp::getDefaultIndexingMaps(getContext()),
3957 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3958 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3959 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3960
3961 std::array<StringRef, 3> elidedAttrs = {
3962 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3963 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3964 elidedAttrs);
3965}
3966
3967/// Verify the user defined indexing maps.
3968LogicalResult MatmulOp::verify() {
3969 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3970 if (!hasUserDefinedMaps())
3971 return success();
3972
3973 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3974 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3975 return failure();
3976 }
3977 return success();
3978}
3979
/// Folds memref.cast ops feeding this op's operands where possible.
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
3983
/// Reports memory effects: none under pure tensor semantics, otherwise the
/// generic linalg effects on the buffer operands.
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
3991
/// Speculatability is delegated to the generic linalg implementation.
Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
3995
3996SmallVector<AffineMap>
3997MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3998 AffineExpr d0, d1, d2;
3999 MLIRContext *context = builder.getContext();
4000 bindDims(context, d0, d1, d2);
4001 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
4002 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
4003 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4004 return {mapLHS, mapRHS, mapOut};
4005}
4006
4008 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4009 if (!maps)
4010 return false;
4011 if (maps.size() != 3)
4012 return false;
4013 auto positions = getAffineResultPositions(maps);
4014 if (failed(positions))
4015 return false;
4016 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4017 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4018 (*positions)[2] == SmallVector<int64_t>{0, 1};
4019}
4020
4023 ValueRange inputs, ValueRange outputs,
4024 ArrayRef<NamedAttribute> attributes) {
4025 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4026 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4027}
4028
4031 ValueRange inputs, ValueRange outputs,
4032 ArrayRef<NamedAttribute> attributes) {
4033 OperationState state(location, getOperationName());
4034 build(builder, state, inputs, outputs, attributes);
4035 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4036 assert(res && "builder didn't return the right type");
4037 return res;
4038}
4039
4042 TypeRange resultTensorTypes,
4043 ValueRange inputs, ValueRange outputs,
4044 ArrayRef<NamedAttribute> attributes) {
4045 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4046 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4047}
4048
4051 TypeRange resultTensorTypes, ValueRange inputs,
4052 ValueRange outputs,
4053 ArrayRef<NamedAttribute> attributes) {
4054 OperationState state(location, getOperationName());
4055 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4056 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4057 assert(res && "builder didn't return the right type");
4058 return res;
4059}
4060
4063 TypeRange resultTensorTypes,
4064 ValueRange inputs, ValueRange outputs,
4065 Attribute cast,
4066 ArrayRef<NamedAttribute> attributes) {
4067 result.addAttribute("cast", cast);
4068 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4069 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4070}
4071
4074 TypeRange resultTensorTypes, ValueRange inputs,
4075 ValueRange outputs, Attribute cast,
4076 ArrayRef<NamedAttribute> attributes) {
4077 OperationState state(location, getOperationName());
4078 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4079 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4080 assert(res && "builder didn't return the right type");
4081 return res;
4082}
4083
4085 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4087 op->getAttr("indexing_maps"));
4088}
4089
4091MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4092 AffineExpr d0, d1, d2;
4093 MLIRContext *context = builder.getContext();
4094 bindDims(context, d0, d1, d2);
4095 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4096 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4097 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4098 return {mapLHS, mapRHS, mapOut};
4099}
4100
4102 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4103 if (!maps)
4104 return false;
4105 if (maps.size() != 3)
4106 return false;
4107 auto positions = getAffineResultPositions(maps);
4108 if (failed(positions))
4109 return false;
4110 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4111 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4112 (*positions)[2] == SmallVector<int64_t>{0, 1};
4113}
4114
4117 ValueRange inputs, ValueRange outputs,
4118 ArrayRef<NamedAttribute> attributes) {
4119 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4120 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4121}
4122
4125 ValueRange inputs, ValueRange outputs,
4126 ArrayRef<NamedAttribute> attributes) {
4127 OperationState state(location, getOperationName());
4128 build(builder, state, inputs, outputs, attributes);
4129 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4130 assert(res && "builder didn't return the right type");
4131 return res;
4132}
4133
4136 TypeRange resultTensorTypes,
4137 ValueRange inputs, ValueRange outputs,
4138 ArrayRef<NamedAttribute> attributes) {
4139 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4140 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4141}
4142
4145 TypeRange resultTensorTypes, ValueRange inputs,
4146 ValueRange outputs,
4147 ArrayRef<NamedAttribute> attributes) {
4148 OperationState state(location, getOperationName());
4149 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4150 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4151 assert(res && "builder didn't return the right type");
4152 return res;
4153}
4154
4157 TypeRange resultTensorTypes,
4158 ValueRange inputs, ValueRange outputs,
4159 Attribute cast,
4160 ArrayRef<NamedAttribute> attributes) {
4161 result.addAttribute("cast", cast);
4162 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4163 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4164}
4165
4168 TypeRange resultTensorTypes, ValueRange inputs,
4169 ValueRange outputs, Attribute cast,
4170 ArrayRef<NamedAttribute> attributes) {
4171 OperationState state(location, getOperationName());
4172 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4173 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4174 assert(res && "builder didn't return the right type");
4175 return res;
4176}
4177
4179 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4181 op->getAttr("indexing_maps"));
4182}
4183
4185BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4186 AffineExpr d0, d1, d2, d3;
4187 MLIRContext *context = builder.getContext();
4188 bindDims(context, d0, d1, d2, d3);
4189 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4190 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4191 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4192 return {mapLHS, mapRHS, mapOut};
4193}
4194
4196 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4197 if (!maps)
4198 return false;
4199 if (maps.size() != 3)
4200 return false;
4201 auto positions = getAffineResultPositions(maps);
4202 if (failed(positions))
4203 return false;
4204 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4205 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4206 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4207}
4208
4210 OpBuilder &builder, OperationState &result, ValueRange inputs,
4211 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4212 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4213 BatchMatmulOp::getRegionBuilder(),
4214 getDefaultIndexingMaps(builder));
4215}
4216
4219 ValueRange inputs, ValueRange outputs,
4220 ArrayRef<NamedAttribute> attributes) {
4221 OperationState state(location, getOperationName());
4222 build(builder, state, inputs, outputs, attributes);
4223 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4224 assert(res && "builder didn't return the right type");
4225 return res;
4226}
4227
4229 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4230 ValueRange inputs, ValueRange outputs,
4231 ArrayRef<NamedAttribute> attributes) {
4232 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4233 BatchMatmulOp::getRegionBuilder(),
4234 getDefaultIndexingMaps(builder));
4235}
4236
4239 TypeRange resultTensorTypes, ValueRange inputs,
4240 ValueRange outputs,
4241 ArrayRef<NamedAttribute> attributes) {
4242 OperationState state(location, getOperationName());
4243 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4244 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4245 assert(res && "builder didn't return the right type");
4246 return res;
4247}
4248
4250 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4251 ValueRange inputs, ValueRange outputs, Attribute cast,
4252 ArrayRef<NamedAttribute> attributes) {
4253 result.addAttribute("cast", cast);
4254 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4255 BatchMatmulOp::getRegionBuilder(),
4256 getDefaultIndexingMaps(builder));
4257}
4258
4261 TypeRange resultTensorTypes, ValueRange inputs,
4262 ValueRange outputs, Attribute cast,
4263 ArrayRef<NamedAttribute> attributes) {
4264 OperationState state(location, getOperationName());
4265 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4266 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4267 assert(res && "builder didn't return the right type");
4268 return res;
4269}
4270
4272 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4274 op->getAttr("indexing_maps"));
4275}
4276
4278BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4279 AffineExpr d0, d1, d2, d3;
4280 MLIRContext *context = builder.getContext();
4281 bindDims(context, d0, d1, d2, d3);
4282 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4283 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4284 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4285 return {mapLHS, mapRHS, mapOut};
4286}
4287
4289 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4290 if (!maps)
4291 return false;
4292 if (maps.size() != 3)
4293 return false;
4294 auto positions = getAffineResultPositions(maps);
4295 if (failed(positions))
4296 return false;
4297 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4298 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4299 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4300}
4301
4303 OpBuilder &builder, OperationState &result, ValueRange inputs,
4304 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4305 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4306 BatchMatmulOp::getRegionBuilder(),
4307 getDefaultIndexingMaps(builder));
4308}
4309
4312 ValueRange inputs, ValueRange outputs,
4313 ArrayRef<NamedAttribute> attributes) {
4314 OperationState state(location, getOperationName());
4315 build(builder, state, inputs, outputs, attributes);
4316 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4317 assert(res && "builder didn't return the right type");
4318 return res;
4319}
4320
4322 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4323 ValueRange inputs, ValueRange outputs,
4324 ArrayRef<NamedAttribute> attributes) {
4325 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4326 BatchMatmulOp::getRegionBuilder(),
4327 getDefaultIndexingMaps(builder));
4328}
4329
4332 TypeRange resultTensorTypes, ValueRange inputs,
4333 ValueRange outputs,
4334 ArrayRef<NamedAttribute> attributes) {
4335 OperationState state(location, getOperationName());
4336 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4337 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4338 assert(res && "builder didn't return the right type");
4339 return res;
4340}
4341
4343 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4344 ValueRange inputs, ValueRange outputs, Attribute cast,
4345 ArrayRef<NamedAttribute> attributes) {
4346 result.addAttribute("cast", cast);
4347 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4348 BatchMatmulOp::getRegionBuilder(),
4349 getDefaultIndexingMaps(builder));
4350}
4351
4354 TypeRange resultTensorTypes, ValueRange inputs,
4355 ValueRange outputs, Attribute cast,
4356 ArrayRef<NamedAttribute> attributes) {
4357 OperationState state(location, getOperationName());
4358 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4359 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4360 assert(res && "builder didn't return the right type");
4361 return res;
4362}
4363
4365 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4367 op->getAttr("indexing_maps"));
4368}
4369
4370//===----------------------------------------------------------------------===//
4371// ContractOp
4372//===----------------------------------------------------------------------===//
4373
4374SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4375 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4376 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4377 // domains are all the same, and each implements a projected permutation.
4378 // Each iteration space dim must occur for at least one operand and either
4379 // takes part in a contraction/reduction or else has parallel iteration type.
4380 // We have that a dim is a contraction/reduction dim if and only if the dim
4381 // occurs for the output operand. We use this fact for fast inference:
4382 // NB: In case we allow dims to occur solely for one input, the above still
4383 // holds: per the einsum semantics, these are reduction dims as well.
4384 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4385 for (auto result : outAffineMap.getResults()) {
4386 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4387 assert(dimExpr && "affine_map is a projected permutation");
4388 dimsInOutput[dimExpr.getPosition()] = true;
4389 }
4390
4392 for (auto dimOccursInOutput : dimsInOutput)
4393 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4394 : utils::IteratorType::reduction);
4395
4396 return iteratorTypes;
4397}
4398
4399unsigned ContractOp::getNumRegionArgs() { return 3; }
4400
4401/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4402void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4403 ArrayRef<NamedAttribute> attrs,
4404 function_ref<InFlightDiagnostic()> emitError) {
4405 if (emitError && block.getNumArguments() != 3) {
4406 emitError() << "ContractOp regionBuilder expects 3 args, got "
4407 << block.getNumArguments();
4408 return;
4409 }
4410 assert(block.getNumArguments() == 3 &&
4411 "ContractOp regionBuilder expects 3 args");
4412 RegionBuilderHelper helper(b, block);
4413
4414 TypeFn castSignedness = TypeFn::cast_signed;
4415 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4416 return attr.getName() == "cast";
4417 });
4418 if (castIter != attrs.end()) {
4419 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4420 castSignedness = attr.getValue();
4421 }
4422
4423 // TODO: Support fields with operators besides mult & add.
4424 Type outType = block.getArgument(2).getType();
4425 Value lhsAtOutType =
4426 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4427 Value rhsAtOutType =
4428 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4429 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4430 rhsAtOutType, emitError);
4431 if (!productAtOutType)
4432 return;
4433 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4434 productAtOutType, emitError);
4435 if (!result)
4436 return;
4437 helper.yieldOutputs({result});
4438}
4439
4440ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4441 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4442 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4443 return parser.emitError(parser.getCurrentLocation(),
4444 "expected 'indexing_maps' attribute");
4445 result.addAttribute("indexing_maps", *indexingMapsAttr);
4446
4447 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4448 regionBuilder);
4449}
4450
4451void ContractOp::print(OpAsmPrinter &p) {
4452 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4454 p, getOperation(), getInputs(), getOutputs(),
4455 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4456}
4457
/// Verifies that the indexing maps form a valid contraction: projected
/// permutations over a shared iteration space, ranks matching operand types,
/// every dim used by some operand, and at least one contracting dim (used by
/// both inputs but not the output).
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
  // access an input operand (so occurrence count can be at most 2) and
  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // A helper so that for each operand's affine_map and type we check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    // the first map seen fixes the iteration-space rank and sizes the counters.
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  // Operand order is (lhs input, rhs input, output accumulator).
  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // We have a contracting dim if and only if ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: We disallow a dim which occurs for only one input operand and not
    //     for the output. In terms of einsum semantics such dims have a
    //     sensible meaning - namely an additional reduction per each such dim.
    //     By contrast, the ContractionOpInterface does not know about this
    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
    //     while vector.contract's verifier accepts dims of this kind many of
    //     its lowerings give up on encountering these dims.
    // TODO: Remove following once we have comprehensive support for input-only
    //       reduction dims, at both the linalg- and vector-dialect levels.
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}
4547
/// Folding hook: absorbs memref.cast producers into the op's operands.
LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4551
4552void ContractOp::getEffects(
4553 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4554 &effects) {
4555 if (hasPureTensorSemantics())
4556 return;
4557 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4558}
4559
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability ContractOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4563
4564//===----------------------------------------------------------------------===//
4565// Implementation of BatchMatmulOp
4566//===----------------------------------------------------------------------===//
4567SmallVector<AffineMap>
4568BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4569 AffineExpr d0, d1, d2, d3;
4570 SmallVector<AffineMap> indexingMaps;
4571 bindDims(context, d0, d1, d2, d3);
4572 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4573 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4574 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4575 return indexingMaps;
4576}
4577
4578bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4579 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4580 if (!maps)
4581 return false;
4582 if (maps.size() != 3)
4583 return false;
4584 auto positions = getAffineResultPositions(maps);
4585 if (failed(positions))
4586 return false;
4587 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4588 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4589 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4590}
4591
4592SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4593 return SmallVector<utils::IteratorType>{
4594 utils::IteratorType::parallel, utils::IteratorType::parallel,
4595 utils::IteratorType::parallel, utils::IteratorType::reduction};
4596}
4597
4598unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4599
/// The external library call name is derived from the operation name.
std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
4603
4604/// Check if the op has broadcast and/or transpose semantic. Returns true if
4605/// the user defined indexing maps are not equal to default map.
4606bool BatchMatmulOp::hasUserDefinedMaps() {
4607 SmallVector<AffineMap, 3> defaultMaps =
4608 getDefaultIndexingMaps(this->getContext());
4609 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4610 return defaultMaps != explicitMaps;
4611}
4612
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
/// broadcast map must include K dimension.
/// TODO: Strict inclusion of K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to have K dimension for one input matrix map and infer the K
/// dimension for other input matrix map from the one already having K
/// dimension.
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  // Symbolic names for the canonical iteration-space dims d0..d3.
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    // Rank-1 broadcast: the single result must be the K dim.
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(0);
    AffineExpr expr1 = bcastMap.getResult(1);
    // LHS: (batch|m, k). RHS: (batch, k) or (k, n).
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
                  expr0.isFunctionOfDim(mPos)) &&
                 expr1.isFunctionOfDim(kPos))
              : ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
  }
  return isValid;
}
4641
4642void BatchMatmulOp::regionBuilder(
4643 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4644 function_ref<InFlightDiagnostic()> emitError) {
4645 if (emitError && block.getNumArguments() != 3) {
4646 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4647 << block.getNumArguments();
4648 return;
4649 }
4650 assert(block.getNumArguments() == 3 &&
4651 "BatchMatmulOp regionBuilder expects 3 args");
4652 RegionBuilderHelper helper(b, block);
4653 SmallVector<Value> yields;
4654
4655 TypeFn castVal = TypeFn::cast_signed;
4656 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4657 return attr.getName() == "cast";
4658 });
4659 if (castIter != attrs.end()) {
4660 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4661 castVal = attr.getValue();
4662 }
4663
4664 auto toType = block.getArgument(2).getType();
4665 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4666 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4667 Value mulVal =
4668 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
4669 if (!castValA || !castValB || !mulVal)
4670 return;
4671 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4672 mulVal, emitError);
4673 if (!addVal)
4674 return;
4675 yields.push_back(addVal);
4676 helper.yieldOutputs(yields);
4677}
4678
/// Parses a batch_matmul op: an optional `indexing_maps = [...]` list of
/// affine-map attributes followed by the common named structured-op syntax.
/// When the list is omitted, the canonical default maps are attached.
ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();

    if (parser.parseLSquare())
      return failure();

    // Comma-separated list of affine-map attributes.
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));

  return ::parseNamedStructuredOp(parser, result,
                                  BatchMatmulOp::getNumRegionArgs(),
                                  BatchMatmulOp::getRegionBuilder());
}
4718
4719void BatchMatmulOp::print(OpAsmPrinter &p) {
4720 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4721 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4722 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4723 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4724 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4725
4726 std::array<StringRef, 3> elidedAttrs = {
4727 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4728 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4729 elidedAttrs);
4730}
4731
4732/// Verify the user defined indexing maps.
4733LogicalResult BatchMatmulOp::verify() {
4734 // Verification of pure batch_matmul is handled by
4735 // verifyStructuredOpInterface().
4736 if (!hasUserDefinedMaps())
4737 return success();
4738
4739 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4741 return failure();
4742 }
4743 return success();
4744}
4745
/// Folding hook: absorbs memref.cast producers into the op's operands.
LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4750
4751void BatchMatmulOp::getEffects(
4752 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4753 &effects) {
4754 if (hasPureTensorSemantics())
4755 return;
4756 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4757}
4758
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4762
4763//===----------------------------------------------------------------------===//
4764// ElementwiseOp
4765//===----------------------------------------------------------------------===//
4766//
namespace {
/// Pairs an elementwise operation's arity group with the concrete function
/// kind inside that group, as decoded from an ElementwiseKind value.
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ..}
  ElementwiseArityGroup arityGroup;

  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  // NOTE: which union member is active is determined by `arityGroup`.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

/// Numeric value of the arity group (also the number of op inputs).
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}
} // namespace
4784
4785static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4786 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4787 constexpr int lastBinary =
4788 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4789 constexpr int lastTernary =
4790 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4791
4792 int val = static_cast<int>(kind);
4793 ArityGroupAndKind result;
4794
4795 if (val < lastUnary) {
4796 result.arityGroup = ElementwiseArityGroup::Unary;
4797 result.kind.unaryFn = static_cast<UnaryFn>(val);
4798 return result;
4799 }
4800 if (val < lastBinary) {
4801 result.arityGroup = ElementwiseArityGroup::Binary;
4802 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4803 return result;
4804 }
4805 if (val >= lastTernary) {
4806 llvm_unreachable("unhandled ElementwiseFn");
4807 }
4808 result.arityGroup = ElementwiseArityGroup::Ternary;
4809 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4810 return result;
4811}
4812
/// Every iteration dim of an elementwise op is parallel; the rank is taken
/// from the result.
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
  auto rank = getResultRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
4817
4819ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4820 MLIRContext *context) {
4821 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4822 return SmallVector<AffineMap>(numMaps, map);
4823}
4824
/// Parses an elementwise op: a mandatory `kind = #linalg.elemwise_kind<...>`
/// attribute, an optional `indexing_maps = [...]` list, then the common named
/// structured-op syntax. Defaults to identity indexing maps (rank inferred
/// from the output operand) when none are supplied.
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
  Attribute attr;
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  if (parser.parseKeyword("kind") || parser.parseEqual())
    return failure();

  if (succeeded(parser.parseAttribute(attr))) {
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  }
  result.addAttribute(
      "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));

  // Parse optional `indexing_maps`
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();
    // Comma-separated list of affine-map attributes.
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr))
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      indexingMapsAttr.push_back(mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }
  // At this stage of parsing the only way to infer number of region
  // args is through op kind, as input output tensors are not parsed yet.
  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
  int numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    // We need to infer the numDims of the indexing maps from the output
    // type which is already parsed by now.
    auto resultType = result.operands[result.operands.size() - 1].getType();
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }

  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return success();
}
4897
4898void ElementwiseOp::print(OpAsmPrinter &p) {
4899 p << " kind=";
4900 p.printAttribute(getKindAttr());
4901 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4902 "indexing_maps"};
4903 unsigned arity =
4904 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4905 unsigned numDims = getResultRank();
4906
4907 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4908 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4909 getContext()),
4910 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4911
4912 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4913 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4914
4915 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4916 elidedAttrs);
4917}
4918
4919/// Implements the block region builder for the ElementwiseOp. This is called by
4920/// 'fillStructuredOpRegion'.
4921void ElementwiseOp::regionBuilder(
4922 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4923 function_ref<InFlightDiagnostic()> emitError) {
4924 ElementwiseKind elemwiseKind;
4925 for (auto attr : attrs) {
4926 if (attr.getName() == b.getStringAttr("kind")) {
4927 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4928 assert(kindAttr && "op kind attribute incorrectly set");
4929 elemwiseKind = kindAttr.getValue();
4930 break;
4931 }
4932 }
4933
4934 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4935 auto arityGroup = groupAndKind.arityGroup;
4936 auto kind = groupAndKind.kind;
4937 if (emitError && block.getNumArguments() !=
4938 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4939 emitError() << "Elementwise regionBuilder expects "
4940 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4941 << block.getNumArguments();
4942 return;
4943 }
4944 assert(block.getNumArguments() ==
4945 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4946 && "Elementwise regionBuilder number of block args mismatch");
4947
4948 RegionBuilderHelper helper(b, block);
4949 SmallVector<Value> yields;
4950 Value result;
4951
4952 if (arityGroup == ElementwiseArityGroup::Unary) {
4953 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4954
4955 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4956 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4957 block.getArgument(1));
4958
4959 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4960 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4961 block.getArgument(1), block.getArgument(2));
4962
4963 } else {
4964 assert(false && "found unhandled category in elemwise");
4965 }
4966
4967 yields.push_back(result);
4968 helper.yieldOutputs(yields);
4969}
4970
/// Folding hook: absorbs memref.cast producers into the op's operands.
LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4975
/// Reports generic structured-op memory effects; the pure tensor form is
/// side-effect free and reports nothing.
void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
4983
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability ElementwiseOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4987
4988//===----------------------------------------------------------------------===//
4989// PackOp/UnPackOp Common
4990//===----------------------------------------------------------------------===//
4991
4992template <typename OpTy, typename>
4993SmallVector<int64_t>
4995 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4996 ? packOrUnPack.getDestType()
4997 : packOrUnPack.getSourceType();
4998 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4999 ? packOrUnPack.getSourceType()
5000 : packOrUnPack.getDestType();
5002 packedType.getShape().take_front(unpackedType.getRank()));
5003 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5005 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
5006 }
5007 return result;
5008}
5013
// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
// * a dim from newPackedTy is static, and
// * the corresponding size from mixedTiles is still dynamic.
// Otherwise, the original tile size is preserved.
// Note - packed-type-dim and mixed-tile-size should always match!
// NOTE(review): the signature lines were lost in extraction and have been
// reconstructed; the rewriter is only used as a Builder here.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  // Walk the trailing (tile) dims of the packed type in lockstep with the
  // current mixed tile sizes.
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t dimSize = std::get<0>(it);
    // Dynamic dim in the new type: keep the existing tile size unchanged.
    if (dimSize == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, update the dynamic mixed-size
    // (provided the original value is dynamic).
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == dimSize &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          (rewriter.getIntegerAttr(rewriter.getIndexType(), dimSize)));
    }
  }

  return newMixedTileSizes;
}
5051
/// Reifies the result shape of a pack/unpack op as the dims of its `dest`
/// operand (these ops are destination-style, so the result shape is the dest
/// shape).
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  for (auto dim : llvm::seq<int64_t>(0, destRank))
    reifiedReturnShapes[0][dim] =
        createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
  return success();
}
5065
/// Returns a map from each tiled dimension (entry of `inner_dims_pos`) to its
/// mixed (static-or-dynamic) tile size.
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // bind the dimension `i` with the tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
5080
/// Returns the inner tile sizes as a list of OpFoldResults: static sizes
/// become IntegerAttrs, dynamic ones the corresponding SSA operands.
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  // Dynamic entries in static_inner_tiles are placeholders; each one consumes
  // the next dynamic tile operand.
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (ShapedType::isStatic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}
5096
/// Returns the inner tile sizes as static integers; dynamic tiles are encoded
/// as ShapedType::kDynamic by dispatchIndexOpFoldResults.
template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}
5106
/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicate.
/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
/// c) The number of elements in `dimsPos` is > than `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  // Duplicates collapse in the set, so a size mismatch means duplicates.
  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
5123
/// Shared verifier for PackOp and UnPackOp. Checks, in order: ranked
/// source/dest, uniform tensor-or-buffer semantics, result count, non-zero
/// tiles, validity of inner_dims_pos/outer_dims_perm, rank relationship
/// between packed and unpacked types, inner-tile/packed-dim agreement, and
/// compatibility of the packed shape with the inferred one.
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(tiles, [](OpFoldResult tile) {
      return isa<Attribute>(tile) && isZeroInteger(tile);
    });
  };

  // Verify that the source and destination are ranked types.
  if (!packOrUnPack.getSourceType().hasRank() ||
      !packOrUnPack.getDestType().hasRank())
    return op->emitError("expected both source and destination to have rank");

  // Verify that the Operation does not have mixed tensor/buffer semantics.
  if (!packOrUnPack.hasPureBufferSemantics() &&
      !packOrUnPack.hasPureTensorSemantics())
    return op->emitError("mixing tensor and buffer semantics is not allowed");
  // Tensor form yields exactly one result; memref form yields none.
  const unsigned numResults = packOrUnPack.getNumResults();
  if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
    return op->emitError("expected 1 result, got ") << numResults;
  if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
    return op->emitError("expected 0 results, got ") << numResults;

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  // The "unpacked" side is the source of a pack / the dest of an unpack.
  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                ? packOrUnPack.getSourceType()
                                : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and must match the number of `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require output rank to match input rank + number of blocking factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify result shape is greater than the minimum expected
  // by the pack operation, and that the output shape
  // represents full tiles.
  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
      unpackedType.getShape(), packOrUnPack.getStaticTiles(),
      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
  // The trailing `mixedTiles.size()` dims of the packed type must agree with
  // the static tile sizes (or both be dynamic).
  for (auto it : llvm::enumerate(llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
    int64_t dimSize = std::get<0>(it.value());
    if (Attribute attr =
            llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
      IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
      int64_t staticTileSize = intAttr.getValue().getSExtValue();
      if (dimSize != staticTileSize)
        return op->emitError(
                   "mismatch in inner tile sizes specified and shaped of "
                   "tiled dimension in the packed type at index ")
               << it.index() << ": got " << dimSize << " != " << staticTileSize;
    } else if (!ShapedType::isDynamic(dimSize)) {
      return op->emitError("mismatch in inner tile sizes specified at index ")
             << it.index() << ": got static shape " << dimSize
             << " but dynamic tile size";
    }
  }
  if (failed(
          verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
    auto elementType = unpackedType.getElementType();
    Type expectedType, actualType;
    if (packOrUnPack.hasPureTensorSemantics()) {
      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
      actualType = RankedTensorType::get(packedType.getShape(), elementType);
    } else {
      expectedType = MemRefType::get(expectedPackedShape, elementType);
      actualType = MemRefType::get(packedType.getShape(), elementType);
    }
    return op->emitError("expected ")
           << expectedType << " for the packed domain value, got "
           << actualType;
  }
  return success();
}
5235
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  // Transposed inner_dims_pos, inner tile sizes, and outer_dims_perm.
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace
5249
/// Applies `innerPermutation` / `outerPermutation` (either may be empty, but
/// not both) to the op's pack/unpack metadata and returns the transposed
/// inner_dims_pos, inner tiles, and outer_dims_perm.
template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  // The outer dims live on the unpacked side: source rank for pack, dest rank
  // for unpack.
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  // An absent outer_dims_perm is treated as the identity permutation.
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
5286
5287//===----------------------------------------------------------------------===//
5288// PackOp
5289//===----------------------------------------------------------------------===//
5290
5291void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5292 if (!getResults().empty())
5293 setNameFn(getResult(), "pack");
5294}
5295
/// Custom parser. Grammar (kept in sync with PackOp::print):
///   linalg.pack $source (`padding_value` `(` $padding_value `:` type `)`)?
///     (`outer_dims_perm` `=` `[` int-list `]`)?
///     `inner_dims_pos` `=` `[` int-list `]`
///     `inner_tiles` `=` dynamic-index-list `into` $dest attr-dict
///     `:` type($source) `->` type($dest)
// NOTE(review): a few declaration/lambda lines were lost in extraction and
// have been reconstructed; confirm against the original file.
ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand source, dest;
  SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
  SmallVector<Type> paddingValueType;
  SmallVector<int64_t> staticTiles;
  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
  Type sourceType, destType, resultType;

  if (parser.parseOperand(source))
    return failure();

  // Optional `padding_value(%pad : type)`.
  if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
    if (parser.parseLParen() ||
        parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
        parser.parseColon() || parser.parseTypeList(paddingValueType) ||
        parser.parseRParen())
      return failure();
  }

  // Optional `outer_dims_perm = [...]`.
  if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
    if (parser.parseEqual())
      return failure();

    SmallVector<int64_t> outerDimsPermVec;
    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                       [&]() -> ParseResult {
          int64_t value;
          if (parser.parseInteger(value))
            return failure();
          outerDimsPermVec.push_back(value);
          return success();
        }))
      return failure();
    outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
  }

  // Mandatory `inner_dims_pos = [...]`.
  if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
    return failure();

  SmallVector<int64_t> innerDimsPosVec;
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square,
                                     [&]() -> ParseResult {
        int64_t value;
        if (parser.parseInteger(value))
          return failure();
        innerDimsPosVec.push_back(value);
        return success();
      }))
    return failure();
  innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);

  // `inner_tiles = [...]` as a mixed static/dynamic index list.
  if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
    return failure();

  DenseI64ArrayAttr staticTilesAttr;
  if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
    return failure();
  for (auto val : staticTilesAttr.asArrayRef())
    staticTiles.push_back(val);

  if (parser.parseKeyword("into") || parser.parseOperand(dest))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (parser.parseColon() || parser.parseType(sourceType))
    return failure();

  bool hasArrow = succeeded(parser.parseOptionalArrow());
  if (hasArrow) {
    if (parser.parseType(destType))
      return failure();
  }

  bool isMemRef = llvm::isa<MemRefType>(sourceType);
  // Both tensor and memref forms require an explicit destination type.
  if (!hasArrow) {
    return parser.emitError(parser.getCurrentLocation(),
                            "pack/unpack requires '->' and destination type");
  }

  // Only the tensor form produces a result value.
  if (!isMemRef)
    resultType = destType;

  if (parser.resolveOperand(source, sourceType, result.operands) ||
      parser.resolveOperand(dest, destType, result.operands))
    return failure();

  if (!paddingValue.empty() &&
      parser.resolveOperands(paddingValue, paddingValueType[0],
                             result.operands))
    return failure();

  if (!dynamicTiles.empty() &&
      parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  result.addAttribute("static_inner_tiles",
                      parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
  result.addAttribute("inner_dims_pos", innerDimsPos);
  if (outerDimsPerm)
    result.addAttribute("outer_dims_perm", outerDimsPerm);

  // Operand groups: source, dest, optional padding value, dynamic tiles.
  SmallVector<int32_t> segmentSizes = {
      1, 1, static_cast<int32_t>(paddingValue.size()),
      static_cast<int32_t>(dynamicTiles.size())};
  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));

  if (!isMemRef)
    result.addTypes(resultType);

  return success();
}
5410
/// Custom printer; the output must stay in sync with PackOp::parse.
void PackOp::print(OpAsmPrinter &p) {
  p << " " << getSource();

  if (getPaddingValue()) {
    p << " padding_value(" << getPaddingValue() << " : "
      << getPaddingValue().getType() << ")";
  }

  // `outer_dims_perm` is optional and elided when empty.
  if (!getOuterDimsPerm().empty()) {
    p << " outer_dims_perm = [";
    llvm::interleaveComma(getOuterDimsPerm(), p);
    p << "]";
  }

  p << " inner_dims_pos = [";
  llvm::interleaveComma(getInnerDimsPos(), p);
  p << "]";

  p << " inner_tiles = ";
  printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());

  p << " into " << getDest();

  // Elide attributes already printed inline above (and segment sizes).
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {"static_inner_tiles", "inner_dims_pos",
                           "outer_dims_perm", "operandSegmentSizes"});

  p << " : " << getSource().getType();
  p << " -> " << getDest().getType();
}
5441
/// Builds a PackOp from mixed inner tile sizes, splitting them into a static
/// sizes attribute plus dynamic size operands. `paddingValue` and
/// `outerDimsPerm` are optional; the result type is taken from `dest`.
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
5460
/// Delegates to the shared pack/unpack implementation (result shape is the
/// dest shape).
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
5466
/// Maps each tiled source dim to its mixed tile size (shared impl).
DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}
5470
/// Returns tile sizes as OpFoldResults (shared impl).
SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}
5474
/// Returns tile sizes as static integers, kDynamic for dynamic (shared impl).
SmallVector<int64_t> PackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}
5478
5479ArrayRef<int64_t> PackOp::getAllOuterDims() {
5480 ShapedType inputType = getSourceType();
5481 int64_t inputRank = inputType.getRank();
5482 return getDestType().getShape().take_front(inputRank);
5483}
5484
5485SmallVector<int64_t> PackOp::getTiledOuterDims() {
5486 auto innerDimsPos = getInnerDimsPos();
5487 SmallVector<int64_t> outerDims(getAllOuterDims());
5488 SmallVector<int64_t> res;
5489
5490 // Recover the original order of the outer dims.
5491 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5492 invertPermutationVector(outerDimPermInv);
5493 if (!outerDimPermInv.empty())
5494 applyPermutationToVector(outerDims, outerDimPermInv);
5495
5496 // Collect the outer dims corresponding to the tilled inner dims.
5497 for (auto index : innerDimsPos)
5498 res.push_back(outerDims[index]);
5499
5500 return res;
5501}
5502
/// Returns true if packing `inputShape` with `innerTiles` may produce a
/// partial tile, i.e. a `padding_value` is required. The leading
/// `inputShape.size()` dims of `outputShape` are the outer dims of the packed
/// type (the per-dimension tile counts).
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  // Outer dims of the packed type, rewound to the source dim order.
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    // A dynamic input dim cannot be statically proven to need padding.
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile) {
      // Dynamic tile size: padding is needed when the static input dim does
      // not divide evenly into the static outer dim (the number of tiles).
      if (ShapedType::isStatic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else {
      assert(*constantTile != 0 && "static tile size can't be zero");
      if (inputShape[pos] % (*constantTile) != 0) {
        return true;
      }
    }
  }
  return false;
}
5533
/// Strict variant of `requirePaddingValue`: returns true unless every tiled
/// dim is statically known to divide evenly by a static tile size. Any
/// dynamic input dim, outer dim, or tile size is treated as requiring
/// padding.
bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
                                       ArrayRef<int64_t> innerDimsPos,
                                       ArrayRef<int64_t> outputShape,
                                       ArrayRef<int64_t> outerDimsPerm,
                                       ArrayRef<OpFoldResult> innerTiles) {
  // Outer dims of the packed type, rewound to the source dim order.
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]) ||
        ShapedType::isDynamic(outputTileSizes[pos]))
      return true;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile)
      return true;
    assert(*constantTile != 0 && "static tile size can't be zero");
    if (inputShape[pos] % (*constantTile) != 0)
      return true;
  }
  return false;
}
5560
/// Runs the shared pack/unpack verifier, then PackOp-specific checks on the
/// optional padding value.
// NOTE(review): the first condition line was lost in extraction and has been
// reconstructed from the shared verifier defined above.
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify padding value, and bail out if the tile does not divide the
  // dimension fully. In the case of dynamic tile factors or dimensions, having
  // a partial tile is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
5586
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
/// Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}
5601
5602SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5603 ArrayRef<int64_t> innerTileSizes,
5604 ArrayRef<int64_t> innerDimsPos,
5605 ArrayRef<int64_t> outerDimsPerm) {
5606 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5607 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5608 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5609 continue;
5610 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5611 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5612 continue;
5613 }
5614 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5615 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5616 }
5617
5618 // Swap tile loops if outer_dims_perm is available.
5619 if (!outerDimsPerm.empty())
5620 applyPermutationToVector(resultShape, outerDimsPerm);
5621
5622 // Append the inner tile dimensions.
5623 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5624 return resultShape;
5625}
5626
/// Computes the packed result sizes as OpFoldResults: tiled dims become
/// ceil(dim / tile) (folded via affine), outer dims are permuted by
/// `outerDimsPerm`, and the inner tile sizes are appended.
SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);

  // Fold ceildiv(sourceDim, tileSize) for every tiled dim.
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  SmallVector<int64_t> resultTypeShape =
      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
                       asShapeWithAnyValueAsDynamic(innerTileSizes),
                       innerDimsPos, outerDimsPerm);

  // Fix-up `resultDims` to ensure that they are Value's if and only if the
  // result type shape says it's a dynamic dim. This is needed as callers may
  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
  // dynamic dims returned by that.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (ShapedType::isStatic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }

  return resultDims;
}
5663
5664RankedTensorType PackOp::inferPackedTensorType(
5665 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5666 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5667 SmallVector<int64_t> resultShape = inferPackedShape(
5668 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5669 return RankedTensorType::get(resultShape, sourceType.getElementType());
5670}
5671
5672MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5673 ArrayRef<int64_t> innerTileSizes,
5674 ArrayRef<int64_t> innerDimsPos,
5675 ArrayRef<int64_t> outerDimsPerm) {
5676 SmallVector<int64_t> resultShape = inferPackedShape(
5677 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5678 return MemRefType::get(resultShape, sourceType.getElementType());
5679}
5680
/// Creates a `tensor.empty` with the packed shape implied by `source`,
/// `innerTileSizes`, `innerDimsPos` and `outerDimsPerm`, suitable as the
/// `dest` operand of a PackOp. Dynamic source dims are materialized with
/// `tensor.dim` ops.
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  // ceilDiv over OpFoldResults; folds to a constant when both are static.
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  // Collect the source sizes, static as attrs and dynamic as tensor.dim.
  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          tensor::DimOp::create(b, loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  // Scale each tiled dim down to its number of tiles.
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  // Trailing dims are the inner tiles themselves.
  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
}
5713
/// Builds a clone of this PackOp whose metadata (inner_dims_pos, inner tiles,
/// outer_dims_perm) is transposed by `innerPermutation` / `outerPermutation`,
/// together with a freshly created destination tensor of matching shape.
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return PackOp::create(b, loc, getSource(), transposedDest,
                        metadata.innerDimsPos, metadata.innerTiles,
                        getPaddingValue(), metadata.outerDimsPerm);
}
5726
/// Reports memory effects for the buffer-semantics form of pack/unpack:
/// the source memref is read; the dest memref is read and written.
// NOTE(review): the signature and the resource arguments were partially lost
// in extraction and have been reconstructed; confirm against the original.
template <typename OpTy>
static void getPackUnPackEffectsImpl(
    OpTy op,
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // No memory effects for pure tensor semantics
  if (op.hasPureTensorSemantics())
    return;

  for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
      continue;

    if (&opOperand == &op.getSourceMutable()) {
      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    } else if (&opOperand == &op.getDestMutable()) {
      // The dest is both read (existing contents may be preserved) and
      // written.
      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
  }
}
5753
/// Delegates to the shared pack/unpack memory-effect reporting.
void PackOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getPackUnPackEffectsImpl(*this, effects);
}
5759
/// Delegates to the shared pack/unpack memory-effect reporting.
void UnPackOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getPackUnPackEffectsImpl(*this, effects);
}
5765
/// Returns true if the tiles and the tiled dims are constant.
// NOTE(review): the signature line was lost in extraction and has been
// reconstructed; confirm against the original file.
template <typename OpTy>
static bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // The packed side: dest for pack, source for unpack.
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}
5783
/// A pack is speculatable only in tensor form; with a padding value any
/// partial tile is well-defined, otherwise all tiles and tiled dims must be
/// statically constant (the verifier already rejects provable partial tiles).
// NOTE(review): the return statements were lost in extraction and have been
// reconstructed; confirm against the original file.
Speculation::Speculatability PackOp::getSpeculatability() {
  if (!hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  if (getPaddingValue())
    return Speculation::Speculatable;

  // The verifier rejects already operations if we can statically prove that the
  // sizes of the tiles do not divide perfectly the dimension; thus, check only
  // to have constant tiles and tiled inner dimensions.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
5798
5799// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5800// dimensions for pack and unpack.
5801static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5802 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5803 return false;
5804 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5805 return true;
5806 // Outer dims permutation is optional.
5807 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5808 // identity permutation.
5809 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5810 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5811}
5812
5813// Return true if pack and unpack have the same tiles.
5814// Same SSA values or same integer constants.
5815static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5816 auto packTiles = packOp.getMixedTiles();
5817 auto unPackTiles = unPackOp.getMixedTiles();
5818 if (packTiles.size() != unPackTiles.size())
5819 return false;
5820 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5821 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5822 return false;
5823 }
5824 return true;
5825}
5826
5827/// Returns true if the pack op does not need a padding value.
5828static bool paddingIsNotNeeded(PackOp op) {
5829 auto srcType = op.getSourceType();
5830 if (llvm::any_of(op.getInnerDimsPos(),
5831 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5832 return false;
5833 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5834 return false;
5835 return !PackOp::requirePaddingValue(
5836 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5837 op.getOuterDimsPerm(), op.getMixedTiles());
5838}
5839
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    // Tiled dims are rescaled by the tile size, so their sizes cannot be
    // propagated 1:1 between source and dest; skip them.
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    // Nothing to learn when both sides agree on static-ness.
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    // Exactly one side is static: propagate its size to the other.
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
5875
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // Fold a pack(unpack(x)) to x, provided no padding is introduced and both
  // ops agree on types, tiles, and inner/outer dim attributes.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() == packOp.getDestType() &&
        !packOp.getPaddingValue() &&
        hasSameInnerOuterAttribute(packOp, unPackOp) &&
        haveSameTiles(packOp, unPackOp)) {
      rewriter.replaceOp(packOp, unPackOp.getSource());
      return success();
    }
  }

  // Fold optional PaddingValue operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    // Cast the source to the more static inferred type, if it changed.
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    ShapedType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
    }
    // Rewire the pack in place onto the cast operands; the result type
    // follows the (possibly updated) destination type.
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Cast back to the original result type for pre-existing users.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
                                           packOp.getResult());
      rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
    }
    return success();
  }

  return failure();
}
5935
5936template <typename PackOrUnpackOp>
5937static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
5938 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5939 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5940 "Function meant for pack/unpack");
5941 // This is a pad if packing only adds ones and we don't transpose dimensions.
5942
5943 // Check that we are not transposing any dimensions.
5944 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5945 int64_t numPackedDims = innerDimsPos.size();
5946 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5947 if (orderedDims != innerDimsPos) {
5948 // Dimensions don't happen in order.
5949 return false;
5950 }
5951
5952 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5953 int64_t packedRank = packedTensorType.getRank();
5954 // At this point we know that we are taking numPackedDims outer
5955 // dimensions and pushing them all the way as the inner most dimensions.
5956 // What's left on the outer most dimensions is, in this order:
5957 // - the factor of the packed dimensions, then
5958 // - the untouched dimensions
5959 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5960 // if all the dimensions that bubble outerward are ones.
5961 // Therefore check that all the dimensions but the numPackedDims inner most
5962 // ones are ones.
5963 return llvm::all_of(
5964 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5965 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5966}
5967
5968bool PackOp::isLikePad() {
5969 auto packedTensorType =
5970 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5971 return isLikePadUnPad(*this, packedTensorType);
5972}
5973
5974::mlir::LogicalResult
5975PackOp::fold(FoldAdaptor adaptor,
5977 if (!hasPureTensorSemantics())
5978 return failure();
5979 std::optional<Attribute> paddingValue;
5980 if (auto pad = adaptor.getPaddingValue())
5981 paddingValue = pad;
5982 if (OpFoldResult reshapedSource = reshapeConstantSource(
5983 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5984 cast<TensorType>(getDestType()), paddingValue)) {
5985 results.push_back(reshapedSource);
5986 return success();
5987 }
5988 return failure();
5989}
5990
5991/// Folds a tensor.cast op into a consuming PackOp op if the
5992/// `tensor.cast` has source that is more static than the consuming op.
5993///
5994/// Example:
5995/// ```mlir
5996/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5997/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5998/// ```
5999///
6000/// folds into:
6001///
6002/// ```mlir
6003/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
6004/// ```
6007
6008 LogicalResult matchAndRewrite(PackOp op,
6009 PatternRewriter &rewriter) const override {
6010 // TODO: Support Memref PackOp. Temporarily return failure.
6011 if (!op.hasPureTensorSemantics())
6012 return failure();
6013
6015 return failure();
6016
6017 SmallVector<Type> newResultTypes(op->getResultTypes());
6018 SmallVector<Value> newOperands =
6020
6021 // Get the updated mixed-tile-sizes attribute.
6022 SmallVector<OpFoldResult> newMixedTileSizes =
6023 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
6024 if (llvm::any_of(newMixedTileSizes, isZeroInteger))
6025 return failure();
6026
6027 // Clone op.
6028 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6029 // this point. However, in practice, we use them for things that we'd like
6030 // to preserve. Implement a better abstraction.
6031 PackOp newOp =
6032 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6033 op.getInnerDimsPos(), newMixedTileSizes,
6034 op.getPaddingValue(), op.getOuterDimsPerm());
6035 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6036
6037 // Replace op.
6038 Value oldResult = op.getResult();
6039 Value newResult = newOp.getResult();
6041 (newResult.getType() != oldResult.getType())
6042 ? tensor::CastOp::create(rewriter, op->getLoc(),
6043 oldResult.getType(), newResult)
6044 : newResult;
6045
6046 rewriter.replaceOp(op, {replacement});
6047
6048 return success();
6049 }
6050};
6051
6052//===----------------------------------------------------------------------===//
6053// UnPackOp
6054//===----------------------------------------------------------------------===//
6055
6056void UnPackOp::getAsmResultNames(
6057 function_ref<void(Value, StringRef)> setNameFn) {
6058 if (!getResults().empty())
6059 setNameFn(getResult(), "unpack");
6060}
6061
6062// Custom parser for UnPackOp that handles the memref/tensor case distinction
6063ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6064 OpAsmParser::UnresolvedOperand source, dest;
6066 SmallVector<int64_t> staticTiles;
6067 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6068 Type sourceType, destType, resultType;
6069
6070 if (parser.parseOperand(source))
6071 return failure();
6072
6073 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6074 if (parser.parseEqual())
6075 return failure();
6076
6077 SmallVector<int64_t> outerDimsPermVec;
6079 int64_t value;
6080 if (parser.parseInteger(value))
6081 return failure();
6082 outerDimsPermVec.push_back(value);
6083 return success();
6084 }))
6085 return failure();
6086 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6087 }
6088
6089 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6090 return failure();
6091
6092 SmallVector<int64_t> innerDimsPosVec;
6094 int64_t value;
6095 if (parser.parseInteger(value))
6096 return failure();
6097 innerDimsPosVec.push_back(value);
6098 return success();
6099 }))
6100 return failure();
6101 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6102
6103 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6104 return failure();
6105
6106 DenseI64ArrayAttr staticTilesAttr;
6107 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6108 return failure();
6109 for (auto val : staticTilesAttr.asArrayRef())
6110 staticTiles.push_back(val);
6111
6112 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6113 return failure();
6114
6115 if (parser.parseOptionalAttrDict(result.attributes))
6116 return failure();
6117
6118 if (parser.parseColon() || parser.parseType(sourceType))
6119 return failure();
6120
6121 bool hasArrow = succeeded(parser.parseOptionalArrow());
6122 if (hasArrow) {
6123 if (parser.parseType(destType))
6124 return failure();
6125 }
6126
6127 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6128 if (!hasArrow) {
6129 return parser.emitError(parser.getCurrentLocation(),
6130 "pack/unpack requires '->' and destination type");
6131 }
6132
6133 if (!isMemRef)
6134 resultType = destType;
6135
6136 if (parser.resolveOperand(source, sourceType, result.operands) ||
6137 parser.resolveOperand(dest, destType, result.operands))
6138 return failure();
6139
6140 if (!dynamicTiles.empty() &&
6141 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6142 result.operands))
6143 return failure();
6144
6145 result.addAttribute("static_inner_tiles",
6146 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6147 result.addAttribute("inner_dims_pos", innerDimsPos);
6148 if (outerDimsPerm)
6149 result.addAttribute("outer_dims_perm", outerDimsPerm);
6150
6151 SmallVector<int32_t> segmentSizes = {
6152 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6153 result.addAttribute("operandSegmentSizes",
6154 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6155
6156 if (!isMemRef)
6157 result.addTypes(resultType);
6158
6159 return success();
6160}
6161
6162void UnPackOp::print(OpAsmPrinter &p) {
6163 p << " " << getSource();
6164
6165 if (!getOuterDimsPerm().empty()) {
6166 p << " outer_dims_perm = [";
6167 llvm::interleaveComma(getOuterDimsPerm(), p);
6168 p << "]";
6169 }
6170
6171 p << " inner_dims_pos = [";
6172 llvm::interleaveComma(getInnerDimsPos(), p);
6173 p << "]";
6174
6175 p << " inner_tiles = ";
6176 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6177
6178 p << " into " << getDest();
6179
6180 p.printOptionalAttrDict((*this)->getAttrs(),
6181 {"static_inner_tiles", "inner_dims_pos",
6182 "outer_dims_perm", "operandSegmentSizes"});
6183
6184 p << " : " << getSource().getType();
6185 p << " -> " << getDest().getType();
6186}
6187
// Shape reification delegates to the implementation shared with PackOp.
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
6193
// Delegates to the implementation shared with PackOp.
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}
6197
// Delegates to the implementation shared with PackOp.
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}
6201
// Delegates to the implementation shared with PackOp.
SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}
6205
6206ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6207 ShapedType destType = getDestType();
6208 int64_t destRank = destType.getRank();
6209 return getSourceType().getShape().take_front(destRank);
6210}
6211
6212SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6213 auto innerDimsPos = getInnerDimsPos();
6214 SmallVector<int64_t> outerDims(getAllOuterDims());
6215 SmallVector<int64_t> res;
6216
6217 // Recover the original order of the outer dims.
6218 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6219 invertPermutationVector(outerDimPermInv);
6220 if (!outerDimPermInv.empty())
6221 applyPermutationToVector(outerDims, outerDimPermInv);
6222
6223 // Collect the outer dims corresponding to the tilled inner dims.
6224 for (auto index : innerDimsPos)
6225 res.push_back(outerDims[index]);
6226
6227 return res;
6228}
6229
// Structural verification is shared between PackOp and UnPackOp.
LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}
6233
6234Speculation::Speculatability UnPackOp::getSpeculatability() {
6235 if (!hasPureTensorSemantics())
6237 // See PackOp::getSpeculatability.
6240
6242}
6243
6244void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6245 Value dest, ArrayRef<int64_t> innerDimsPos,
6246 ArrayRef<OpFoldResult> innerTiles,
6247 ArrayRef<int64_t> outerDimsPerm) {
6248 assert(innerDimsPos.size() == innerTiles.size() &&
6249 "number of tile sizes specified must match the specified number of "
6250 "original dimensions to be tiled");
6251 SmallVector<int64_t> staticTileSizes;
6252 SmallVector<Value> dynamicTileSizes;
6253 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
6254 build(builder, state, dest.getType(), source, dest,
6255 outerDimsPerm.empty() ? nullptr
6256 : builder.getDenseI64ArrayAttr(outerDimsPerm),
6257 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
6258 builder.getDenseI64ArrayAttr(staticTileSizes));
6259}
6260
6261Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6262 Value source,
6263 ArrayRef<OpFoldResult> innerTileSizes,
6264 ArrayRef<int64_t> innerDimsPos,
6265 ArrayRef<int64_t> outerDimsPerm) {
6266 AffineExpr sym0, sym1;
6267 bindSymbols(b.getContext(), sym0, sym1);
6268 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6269 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6270 };
6271
6272 SmallVector<OpFoldResult> mixedSizes;
6273 auto srcType = llvm::cast<RankedTensorType>(source.getType());
6274 for (auto i :
6275 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6276 if (srcType.isDynamicDim(i))
6277 mixedSizes.push_back(
6278 tensor::DimOp::create(b, loc, source, i).getResult());
6279 else
6280 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6281 }
6282 if (!outerDimsPerm.empty()) {
6284 mixedSizes, invertPermutationVector(outerDimsPerm));
6285 }
6286
6287 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6288 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6289
6290 auto elemType = srcType.getElementType();
6291 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6292}
6293
6294UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
6295 Value transposedSource,
6296 ArrayRef<int64_t> innerPermutation,
6297 ArrayRef<int64_t> outerPermutation) {
6298 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
6299 *this, innerPermutation, outerPermutation);
6300 return UnPackOp::create(b, loc, transposedSource, getDest(),
6301 metadata.innerDimsPos, metadata.innerTiles,
6302 metadata.outerDimsPerm);
6303}
6304
6305/// Returns true if the `srcShape` or `destShape` is different from the one in
6306/// `op` and populates each with the inferred static shape.
6307static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
6308 SmallVectorImpl<int64_t> &destShape) {
6309 bool changeNeeded = false;
6310 srcShape.assign(op.getSourceType().getShape().begin(),
6311 op.getSourceType().getShape().end());
6312 destShape.assign(op.getDestType().getShape().begin(),
6313 op.getDestType().getShape().end());
6314 llvm::SmallSetVector<int64_t, 4> innerDims;
6315 innerDims.insert_range(op.getInnerDimsPos());
6316 SmallVector<int64_t> inverseOuterDimsPerm;
6317 if (!op.getOuterDimsPerm().empty())
6318 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
6319 int destRank = op.getDestRank();
6320 for (auto i : llvm::seq<int64_t>(0, destRank)) {
6321 if (innerDims.contains(i))
6322 continue;
6323 int64_t srcPos = i;
6324 int64_t destPos = i;
6325 if (!inverseOuterDimsPerm.empty())
6326 srcPos = inverseOuterDimsPerm[destPos];
6327 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6328 ShapedType::isDynamic(destShape[destPos])) {
6329 continue;
6330 }
6331 int64_t size = srcShape[srcPos];
6332 if (ShapedType::isDynamic(size))
6333 size = destShape[destPos];
6334 srcShape[srcPos] = size;
6335 destShape[destPos] = size;
6336 changeNeeded = true;
6337 }
6338 return changeNeeded;
6339}
6340
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // TODO: Support Memref UnPackOp. Temporarily return failure.
  if (!unPackOp.hasPureTensorSemantics())
    return failure();

  /// unpack(pack(x)) -> x, when no padding is involved and both ops agree on
  /// types, tiles, and inner/outer dim attributes.
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x): the unpack overwrites its
  /// dest, so the producer's computation into it is dead; use the producer's
  /// init operand directly.
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y)),
  /// only when the unpack has a single use and the slice is foldable
  /// (see canFoldSliceOp).
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      // Slice the destination instead, and shrink the unpack's result type
      // to match it.
      auto newDest = tensor::ExtractSliceOp::create(
          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    // Cast operands to their more static inferred types where they changed.
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = tensor::CastOp::create(rewriter, loc, newSrcType,
                                      unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                    unPackOp.getDest());
    }
    UnPackOp newOp = UnPackOp::create(
        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    // Cast back to the original (less static) result type for existing users.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp.getResult());
    return success();
  }

  return failure();
}
6413
6414bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6415 // Rank-reduced folding is not supported.
6416 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6417 return false;
6418 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6419 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6420 return false;
6421 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6422 SmallVector<int64_t> outerShapeWithoutTranspose =
6424 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
6425 for (auto [pos, tileSize] :
6426 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6427 areOuterDimsTiled[pos] = true;
6428 if (unpackedTypeAfterFold.isDynamicDim(pos))
6429 return false;
6430 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6431 return false;
6432 if (ShapedType::isDynamic(tileSize))
6433 return false;
6434 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6435 unpackedTypeAfterFold.getDimSize(pos);
6436 if (paddingSize >= tileSize)
6437 return false;
6438 }
6439 // extract_slice must not affect dimensions that are not being unpacked
6440 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6441 if (areOuterDimsTiled[pos])
6442 continue;
6443 int64_t dim = outerShapeWithoutTranspose[pos];
6444 if (ShapedType::isDynamic(dim))
6445 return false;
6446 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6447 return false;
6448 }
6449 return true;
6450}
6451
6452bool UnPackOp::isLikeUnPad() {
6453 ShapedType packedTensorType = getSourceType();
6454 return isLikePadUnPad(*this, packedTensorType);
6455}
6456
6457::mlir::LogicalResult
6458UnPackOp::fold(FoldAdaptor adaptor,
6459 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6460 // TODO: Support Memref UnPackOp. Temporarily return failure.
6461 if (!hasPureTensorSemantics())
6462 return failure();
6463
6464 if (OpFoldResult reshapedSource = reshapeConstantSource(
6465 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6466 cast<TensorType>(getResult().getType()))) {
6467 results.push_back(reshapedSource);
6468 return success();
6469 }
6470 return failure();
6471}
6472
6473/// Folds a tensor.cast op into a consuming UnPackOp op if the
6474/// `tensor.cast` has source that is more static than the consuming op.
6475///
6476/// Example:
6477/// ```mlir
6478/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6479/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6480/// ```
6481///
6482/// folds into:
6483///
6484/// ```mlir
6485/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6486/// ```
6487struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6488 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6489
6490 LogicalResult matchAndRewrite(UnPackOp op,
6491 PatternRewriter &rewriter) const override {
6492 // TODO: Support Memref UnPackOp. Temporarily return failure.
6493 if (!op.hasPureTensorSemantics())
6494 return failure();
6495
6497 return failure();
6498
6499 SmallVector<Type> newResultTypes(op->getResultTypes());
6500 SmallVector<Value> newOperands =
6502 Value sourceTensor = newOperands[0];
6503
6504 // Get the updated mixed-tile-sizes attribute.
6506 rewriter, sourceTensor.getType(), op.getMixedTiles());
6507
6508 // Clone op.
6509 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6510 // this point. However, in practice, we use them for things that we'd like
6511 // to preserve. Implement a better abstraction.
6512 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6513 newOperands[1], op.getInnerDimsPos(),
6514 newMixedTileSizes, op.getOuterDimsPerm());
6515 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6516
6517 // Replace op.
6518 Value oldResult = op.getResult();
6519 Value newResult = newOp.getResult();
6521 (newResult.getType() != oldResult.getType())
6522 ? tensor::CastOp::create(rewriter, op->getLoc(),
6523 oldResult.getType(), newResult)
6524 : newResult;
6525
6526 rewriter.replaceOp(op, {replacement});
6527
6528 return success();
6529 }
6530};
6531
6532//===----------------------------------------------------------------------===//
6533// BatchReduceMatmulOp
6534//===----------------------------------------------------------------------===//
6535SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6537 utils::IteratorType::reduction, utils::IteratorType::parallel,
6538 utils::IteratorType::parallel, utils::IteratorType::reduction};
6539}
6540
6541SmallVector<AffineMap>
6542BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6543 AffineExpr d0, d1, d2, d3;
6544 SmallVector<AffineMap> indexingMaps;
6545 bindDims(context, d0, d1, d2, d3);
6546 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6547 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6548 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6549 return indexingMaps;
6550}
6551
6552bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6553 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6554 if (!maps)
6555 return false;
6556 if (maps.size() != 3)
6557 return false;
6558 auto positions = getAffineResultPositions(maps);
6559 if (failed(positions))
6560 return false;
6561 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6562 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6563 (*positions)[2] == SmallVector<int64_t>{1, 2};
6564}
// The region has 3 block arguments: lhs element, rhs element, accumulator.
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6566
// Delegates to the shared library-call-name generator.
std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
6570
6571/// Check if the op has broadcast and/or transpose semantic. Returns true if
6572/// the user defined indexing maps are not equal to default map.
6573bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6574 SmallVector<AffineMap, 3> defaultMaps =
6575 getDefaultIndexingMaps(this->getContext());
6576 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6577 return defaultMaps != explicitMaps;
6578}
6579
6580/// Returns true if the given bcastMap map is a valid broadcast map. A valid
6581/// broadcast map must include K dimension.
6582/// TODO: Strict inclusion of K dimension in the broadcast map is not
6583/// necessary for both input matrices simultaneously. We can relax this
6584/// condition to have K dimension for one input matrix map and infer the K
6585/// dimension for other input matrix map from the one already having K
6586/// dimension.
6587bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6588 bool isLHS) {
6589 assert(bcastMap.getNumResults() < 3 &&
6590 "Expected less than 3 result dim expr.");
6591 bool isValid = false;
6592 enum Indices { batchPos, mPos, nPos, kPos };
6593 if (bcastMap.getNumResults() == 1) {
6594 AffineExpr expr = bcastMap.getResult(0);
6595 isValid = expr.isFunctionOfDim(kPos);
6596 } else if (bcastMap.getNumResults() == 2) {
6597 AffineExpr expr0 = bcastMap.getResult(0);
6598 AffineExpr expr1 = bcastMap.getResult(1);
6599 isValid =
6600 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6601 expr0.isFunctionOfDim(mPos)) &&
6602 expr1.isFunctionOfDim(kPos))
6603 : ((expr0.isFunctionOfDim(batchPos) &&
6604 expr1.isFunctionOfDim(kPos)) ||
6605 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6606 }
6607 return isValid;
6608}
6609
/// Builds the scalar region: out = out + cast(a) * cast(b).
void BatchReduceMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  // Promote both inputs to the accumulator's element type (signed cast).
  auto toType = block.getArgument(2).getType();
  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal =
      helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
  // A null value signals a failed build; stop without yielding.
  if (!castValA || !castValB || !mulVal)
    return;
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  if (!addVal)
    return;
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}
6639
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  // Optional `indexing_maps = [affine_map<...>, ...]` prefix.
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();

    // Comma-separated list of affine map attributes.
    do {
      if (parser.parseAttribute(mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(mapAttr)) {
        return parser.emitError(parser.getCurrentLocation(),
                                "expected affine map attribute");
      }
      indexingMapsAttr.push_back(mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  // Operands, attributes, and region follow the shared named-structured-op
  // syntax.
  return ::parseNamedStructuredOp(parser, result,
                                  BatchReduceMatmulOp::getNumRegionArgs(),
                                  BatchReduceMatmulOp::getRegionBuilder());
}
6678
6679void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6680 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6681 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6682 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6683
6684 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6685 p << " indexing_maps = [";
6686 llvm::interleaveComma(getIndexingMaps(), p,
6687 [&](Attribute attr) { p.printAttribute(attr); });
6688 p << "]";
6689 }
6690
6691 SmallVector<StringRef, 3> elidedAttrs = {
6692 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6693 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6694 elidedAttrs);
6695}
6696
6697/// Verify the user defined indexing maps.
6698LogicalResult BatchReduceMatmulOp::verify() {
6699 // Verification of pure batch_reduce_matmul is handled by
6700 // verifyStructuredOpInterface().
6701 if (!hasUserDefinedMaps())
6702 return success();
6703
6704 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6706 return failure();
6707 }
6708 return success();
6709}
// Folds memref.cast ops feeding this op's operands; no result folding.
LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
// The tensor form is pure; only the memref form reports memory effects.
void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
6721
// Speculatability is decided by the generic structured-op implementation.
Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
6725
6726} // namespace linalg
6727} // namespace mlir
6728
6729//===----------------------------------------------------------------------===//
6730// LinalgDialect
6731//===----------------------------------------------------------------------===//
6732
6733void LinalgDialect::getCanonicalizationPatterns(
6734 RewritePatternSet &results) const {
6735 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6736 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6737}
6738
// Constant materialization is delegated to the arith dialect.
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:59
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:563
result_iterator result_begin()
Definition Operation.h:442
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
unsigned getNumOperands()
Definition Operation.h:375
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:426
result_iterator result_end()
Definition Operation.h:443
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:241
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply the inverse of a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override