MLIR 23.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
14
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Attributes.h"
30#include "mlir/IR/Builders.h"
39
40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/SmallVectorExtras.h"
45#include "llvm/ADT/StringSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/FormatVariadic.h"
48#include "llvm/Support/InterleavedRange.h"
49#include "llvm/Support/LogicalResult.h"
50#include "llvm/Support/MathExtras.h"
51#include "llvm/Support/raw_ostream.h"
52#include <cassert>
53#include <optional>
54
55using namespace mlir;
56using namespace mlir::linalg;
57
58/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
60 int64_t dim) {
61 auto type = cast<ShapedType>(v.getType());
62 if (!type.isDynamicDim(dim))
63 return builder.getIndexAttr(type.getDimSize(dim));
64
65 return getAsOpFoldResult(
67 .Case([&](RankedTensorType t) -> Value {
68 return tensor::DimOp::create(builder, loc, v, dim);
69 })
70 .Case([&](MemRefType t) -> Value {
71 return memref::DimOp::create(builder, loc, v, dim);
72 }));
73}
74
75/// Returns a memref.subview or a tensor.extract_slice based on the type of the
76/// `source`.
80 ArrayRef<OpFoldResult> strides) {
82 .Case([&](RankedTensorType t) -> Operation * {
83 return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
84 strides);
85 })
86 .Case([&](MemRefType type) -> Operation * {
87 return memref::SubViewOp::create(b, loc, source, offsets, sizes,
88 strides);
89 })
90 .Default([&](Type t) -> Operation * { return nullptr; });
91}
92
93//===----------------------------------------------------------------------===//
94// Helper functions
95//===----------------------------------------------------------------------===//
96
98 int64_t dim) {
99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
100 return b.createOrFold<memref::DimOp>(loc, source, dim);
101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
102 return b.createOrFold<tensor::DimOp>(loc, source, dim);
103 llvm_unreachable("Expected MemRefType or TensorType");
104}
105
107 int64_t dim) {
108 auto shapedType = llvm::cast<ShapedType>(source.getType());
109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
110 return createOrFoldDimOp(b, loc, source, dim);
111 return b.getIndexAttr(shapedType.getDimSize(dim));
112}
113
114//===----------------------------------------------------------------------===//
115// Support for named Linalg ops defined in ods-gen.
116//===----------------------------------------------------------------------===//
117
121
122/// Fills the region of a structured operation using the provided
123/// `regionBuilder`. The method is used by both named structured ops created by
124/// ods-gen and by manually defined C++ ops. It is called by both builders and
125/// parsers and creates a block with arguments corresponding to the elemental
126/// types of `inputTypes` and `outputTypes`.
127static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
128 TypeRange inputTypes, TypeRange outputTypes,
131 RegionBuilderFn regionBuilder) {
132 SmallVector<Type, 8> argTypes;
134 for (auto containers : {inputTypes, outputTypes}) {
135 for (auto t : containers) {
136 argTypes.push_back(
137 isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
138
139 // TODO: Pass in a proper location here.
140 argLocs.push_back(opBuilder.getUnknownLoc());
141 }
142 }
143
144 // RAII.
145 OpBuilder::InsertionGuard guard(opBuilder);
146 Block *body =
147 opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
148
149 opBuilder.setInsertionPointToStart(body);
150 ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
151 regionBuilder(b, *body, attrs, emitError);
152
153 // indexing_maps is an auto-generated method.
154
155 // iterator_types is an auto-generated method.
156}
157
158/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
159/// The result types are derived automatically if `resultTensorTypes` is none.
160/// The body of the operation is filled using `regionBuilder`. All ods-gen
161/// created structured operations use the method to implement their builders.
163 std::optional<TypeRange> resultTensorTypes,
164 ValueRange inputs, ValueRange outputs,
165 ArrayRef<NamedAttribute> attributes,
166 RegionBuilderFn regionBuilder) {
167 // Derive the result types if needed.
168 SmallVector<Type> derivedResultTypes =
169 resultTensorTypes.value_or(TypeRange());
170 if (!resultTensorTypes)
171 copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
173
174 state.addOperands(inputs);
175 state.addOperands(outputs);
176 state.addTypes(derivedResultTypes);
177
178 state.addAttributes(attributes);
179 state.addAttribute(
180 "operandSegmentSizes",
181 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
182 static_cast<int32_t>(outputs.size())}));
183
184 // Create and fill the region of the structured operation.
185 Region &region = *state.addRegion();
186 fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
187 state.attributes.getAttrs(), /*emitError=*/{},
188 regionBuilder);
189}
190
192 std::optional<TypeRange> resultTensorTypes,
193 ValueRange inputs, ValueRange outputs,
194 ArrayRef<NamedAttribute> attributes,
195 RegionBuilderFn regionBuilder,
196 ArrayRef<AffineMap> defaultIndexingMaps) {
197 // If indexing maps are not provided, apply the default ones.
198 if (none_of(attributes, [](NamedAttribute attr) {
199 return attr.getName() == "indexing_maps";
200 })) {
201 SmallVector<Attribute, 3> indexingMapsAttrVal;
202 indexingMapsAttrVal = llvm::map_to_vector(
203 defaultIndexingMaps,
204 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
205 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
206 }
207 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
208 attributes, regionBuilder);
209}
210
212 std::optional<TypeRange> resultTensorTypes,
213 ValueRange inputs, ValueRange outputs,
214 ArrayRef<NamedAttribute> attributes,
215 RegionBuilderFn regionBuilder,
216 ArrayRef<AffineMap> defaultIndexingMaps) {
217 // If indexing maps are not provided, apply the default ones.
218 if (none_of(attributes, [](NamedAttribute attr) {
219 return attr.getName() == "indexing_maps";
220 })) {
221 SmallVector<Attribute, 4> indexingMapsAttrVal;
222 indexingMapsAttrVal = llvm::map_to_vector(
223 defaultIndexingMaps,
224 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
225 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
226 }
227 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
228 attributes, regionBuilder);
229}
230
232 std::optional<TypeRange> resultTensorTypes,
233 ValueRange inputs, ValueRange outputs,
234 ArrayRef<NamedAttribute> attributes,
235 RegionBuilderFn regionBuilder,
236 ArrayRef<AffineMap> indexingMaps) {
237 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
238 SmallVector<Attribute, 4> indexingMapsAttrVal;
239 indexingMapsAttrVal =
240 llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
241 return AffineMapAttr::get(map);
242 });
243 state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
244 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
245 attributes, regionBuilder);
246}
247
248/// Common parsing used for both named structured ops created by ods-gen and by
249/// manually defined C++ ops. Does not handle regions.
250static ParseResult
252 SmallVectorImpl<Type> &inputTypes,
253 SmallVectorImpl<Type> &outputTypes,
254 bool addOperandSegmentSizes = true) {
255 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
257 outputsOperands;
258
259 if (succeeded(parser.parseOptionalLess())) {
260 if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
261 return failure();
262 }
263 attrsLoc = parser.getCurrentLocation();
264 if (parser.parseOptionalAttrDict(result.attributes))
265 return failure();
266
267 if (succeeded(parser.parseOptionalKeyword("ins"))) {
268 if (parser.parseLParen())
269 return failure();
270
271 inputsOperandsLoc = parser.getCurrentLocation();
272 if (parser.parseOperandList(inputsOperands) ||
273 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
274 return failure();
275 }
276
277 if (succeeded(parser.parseOptionalKeyword("outs"))) {
278 outputsOperandsLoc = parser.getCurrentLocation();
279 if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
280 parser.parseColonTypeList(outputTypes) || parser.parseRParen())
281 return failure();
282 }
283
284 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
285 result.operands) ||
286 parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
287 result.operands))
288 return failure();
289
290 if (addOperandSegmentSizes) {
291 // This is a bit complex because we're trying to be backward compatible with
292 // operation syntax that mix the inherent attributes and the discardable
293 // ones in the same dictionary. If the properties are used, we append the
294 // operandSegmentSizes there directly. Otherwise we append it to the
295 // discardable attributes dictionary where it is handled by the generic
296 // Operation::create(...) method.
297 if (result.propertiesAttr) {
298 NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
299 attrs.append("operandSegmentSizes",
301 {static_cast<int32_t>(inputsOperands.size()),
302 static_cast<int32_t>(outputsOperands.size())}));
303 result.propertiesAttr = attrs.getDictionary(parser.getContext());
304 } else {
305 result.addAttribute("operandSegmentSizes",
307 {static_cast<int32_t>(inputsOperands.size()),
308 static_cast<int32_t>(outputsOperands.size())}));
309 }
310 }
311 if (!result.propertiesAttr) {
312 std::optional<RegisteredOperationName> info =
313 result.name.getRegisteredInfo();
314 if (info) {
315 if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
316 return parser.emitError(attrsLoc)
317 << "'" << result.name.getStringRef() << "' op ";
318 })))
319 return failure();
320 }
321 }
322 return success();
323}
324
326 ValueRange outputs) {
327 if (!inputs.empty())
328 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
329 if (!outputs.empty())
330 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
331}
332
333//===----------------------------------------------------------------------===//
334// Specific parsing and printing for named structured ops created by ods-gen.
335//===----------------------------------------------------------------------===//
336
338 OpAsmParser &parser, Region &region, unsigned numRegionArgs,
339 TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
340 RegionBuilderFn regionBuilder, SMLoc loc) {
341 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
342 return parser.emitError(
343 parser.getCurrentLocation(),
344 llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
345 "region expects {0} args, got {1}",
346 numRegionArgs, inputTypes.size() + outputTypes.size()));
347 }
348
349 OpBuilder opBuilder(parser.getContext());
350 ParseResult result = success();
352 opBuilder, region, inputTypes, outputTypes, attrs,
353 [&]() {
354 result = failure();
355 return parser.emitError(loc);
356 },
357 regionBuilder);
358 return result;
359}
360
361static ParseResult
363 SmallVectorImpl<Type> &resultTypes) {
364 if (parser.parseOptionalArrowTypeList(resultTypes))
365 return failure();
366 return success();
367}
368
369static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
371 unsigned numRegionArgs,
372 RegionBuilderFn regionBuilder) {
373 // TODO: Enable when ods-gen supports captures.
374 SmallVector<Type, 1> inputTypes, outputTypes;
375 SMLoc loc = parser.getCurrentLocation();
376 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
377 return failure();
378
379 // Parse optional attributes.
380 if (parser.parseOptionalAttrDict(result.attributes))
381 return failure();
382
383 // TODO: consider merging results parsing into region parsing.
384 // Need to wait for declarative assembly resolution to decide.
385 SmallVector<Type, 1> outputTensorsTypes;
386 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
387 return failure();
388 result.addTypes(outputTensorsTypes);
389
390 std::unique_ptr<Region> region = std::make_unique<Region>();
391 if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
392 outputTypes, result.attributes.getAttrs(),
393 regionBuilder, loc))
394 return failure();
395 result.addRegion(std::move(region));
396
397 return success();
398}
399
401 TypeRange resultTypes) {
402 if (resultTypes.empty())
403 return;
404 p.printOptionalArrowTypeList(resultTypes);
405}
406
408 ValueRange inputs, ValueRange outputs,
409 ArrayRef<StringRef> elidedAttrs = {}) {
410 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
411
412 // Printing is shared with generic ops, except for the region and
413 // attributes.
414 printCommonStructuredOpParts(p, inputs, outputs);
415
416 // Results printing.
418
419 // Region is elided.
420}
421
422//===----------------------------------------------------------------------===//
423// Region builder helper.
424// TODO: Move this to a utility library.
425// The public methods on this class are referenced directly from generated code.
426// Helper build the unary, binary, and type conversion functions defined by the
427// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
428// class.
429//
430// Implementations of the math functions must be polymorphic over numeric types,
431// internally performing necessary casts. If the function application makes no
432// sense, then the only recourse is to assert and return nullptr. This can be
433// extended later if it becomes possible to fail construction of the region. The
434// invariant should be enforced at a higher level.
435//
436// TODO: These helpers are currently type polymorphic over the class of integer
437// and floating point types, but they will not internally cast within bit
438// widths of a class (mixed precision such as i8->i32) or across classes
439// (i.e. mixed float and integer). Many such combinations are ambiguous or need
440// to be handled with care and work is being considered to extend the op
441// language to make such cases explicit. In the mean-time, violating this will
442// fail verification, which is deemed acceptable.
443//===----------------------------------------------------------------------===//
444
445namespace {
446
/// Helper used by generated code (LinalgNamedStructuredOps.yamlgen.cpp.inc)
/// to build the scalar payload of a structured op region. Every builder
/// method inserts at the end of `block` and restores the insertion point on
/// exit. Unsupported cases either report through `emitError` (returning
/// nullptr) or assert when no callback is given.
class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  // All unary functions require a floating-point argument; failures are
  // reported via `emitError` (nullptr result) or assert.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      // reciprocal(x) lowers to (1.0 / x).
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      // square(x) lowers to (x * x).
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error will be emitted if the operation is not
  // supported and a nullptr will be returned, otherwise an assertion will be
  // raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    // Classify the operand pair once; the switch below picks the dialect op
    // (complex/arith float/arith int) accordingly. i1 operands get boolean
    // lowerings (e.g. add -> ori, mul -> andi).
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        // Subtraction has no boolean lowering.
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        // Division has no boolean lowering.
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      // Unsigned division is only defined for (non-boolean) integers.
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned max not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned max not on uint");
      }
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned min not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned min not on uint");
      }
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  // Currently only `select` is defined.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  // Both cases delegate to the private `cast` helper; the boolean selects
  // unsigned vs signed conversion semantics.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  // Terminate the block with a linalg.yield of `values`.
  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  // Materialize an arith.constant parsed from the textual attribute `value`.
  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  // Materialize a linalg.index for loop dimension `dim`.
  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    // Prefer a real location over UnknownLoc for better diagnostics: fall
    // back to the defining op's, then the parent op's location.
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  // Type-classification predicates used by the build* methods above.
  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};
712
713} // namespace
714
715//===----------------------------------------------------------------------===//
716// CopyOp
717//===----------------------------------------------------------------------===//
718
719namespace {
720
721struct EraseSelfCopy : OpRewritePattern<CopyOp> {
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter) const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
726 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
727 if (copyOp.hasPureBufferSemantics())
728 rewriter.eraseOp(copyOp);
729 else
730 rewriter.replaceOp(copyOp, copyOp.getInputs());
731
732 return success();
733 }
734};
735
736} // namespace
737
/// Registers CopyOp canonicalization patterns; currently only the self-copy
/// erasure pattern defined above.
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
742
743//===----------------------------------------------------------------------===//
744// FillOp
745//===----------------------------------------------------------------------===//
746
747namespace {
748
749/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
750///
751/// For such op chains, we can create new linalg.fill ops with the result
752/// type of the tensor.expand/collapse_shape op.
753template <typename TensorReshapeOp>
754struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter) const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
759 if (!oldFill)
760 return failure();
761
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
765
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
770 } else {
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
774 }
775 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
776 ValueRange{newInit});
777 return success();
778 }
779};
780
781/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
782/// filling value are the same.
783struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
785
786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
787 PatternRewriter &rewriter) const override {
788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
789 if (!fillOp)
790 return failure();
791
792 // We can only fold if the padding value is the same as the original
793 // filling value.
794 Value padValue = padOp.getConstantPaddingValue();
795 if (!padValue || fillOp.value() != padValue)
796 return failure();
797
798 ReifiedRankedShapedTypeDims reifiedShape;
799 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
800 return rewriter.notifyMatchFailure(
801 padOp, "failed to reify tensor.pad op result shape");
802
803 auto emptyTensor =
804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
805 padOp.getResultType().getElementType());
806 Value replacement =
807 FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
808 ValueRange{emptyTensor})
809 .getResult(0);
810 if (replacement.getType() != padOp.getResultType()) {
811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
812 padOp.getResultType(), replacement);
813 }
814 rewriter.replaceOp(padOp, replacement);
815 return success();
816 }
817};
818
819/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
820/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
821/// filling value are the same.
822struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
824
825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
826 PatternRewriter &rewriter) const override {
827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
828 if (!srcPadOp)
829 return failure();
830
831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
832 return failure();
833
834 // Walk back the tensor.insert_slice chain and find the first destination
835 // value at the start of the chain.
836 Value firstDest = insertOp.getDest();
837 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
839 return failure();
840
841 // Make sure the range of values accessed are disjoint. Without this, we
842 // cannot fold tensor.pad away.
843 bool disjoint = false;
844 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
845 // If the dimension has dynamic offset/size, we cannot guarantee
846 // disjoint. So just skip it.
847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
850 continue;
851
852 // Get the range start and end, inclusively for both.
853 int64_t prevStart = prevOp.getStaticOffset(i);
854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
855 prevOp.getStaticStride(i);
856 int64_t nextStart = insertOp.getStaticOffset(i);
857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
858 insertOp.getStaticStride(i);
859 if (prevEnd < nextStart || nextEnd < prevStart) {
860 disjoint = true;
861 break;
862 }
863 }
864
865 if (!disjoint)
866 break;
867 firstDest = prevOp.getDest();
868 }
869
870 // Check whether the first destination is a fill op. For overlapped cases,
871 // this also cannot be true.
872 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
873 if (!dstFillOp)
874 return failure();
875
876 // We can only fold if the padding value is the same as the original
877 // filling value.
878 Value padValue = srcPadOp.getConstantPaddingValue();
879 if (!padValue || dstFillOp.value() != padValue)
880 return failure();
881
882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
884
885 Location loc = insertOp.getLoc();
886 MLIRContext *context = getContext();
887
888 AffineExpr sym0, sym1;
889 bindSymbols(context, sym0, sym1);
890 auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
891
892 // Calculate the new offsets for the insert. It should be the old offsets
893 // plus low padding sizes.
894 SmallVector<OpFoldResult, 4> newOffsets;
895 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
896 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
898 }
899
900 RankedTensorType srcPadType = srcPadOp.getSourceType();
901 SmallVector<OpFoldResult, 4> newSizes;
902 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
903 if (srcPadType.isDynamicDim(i)) {
904 newSizes.push_back(
905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
906 .getResult());
907 } else {
908 newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
909 }
910 }
911
912 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
914 newSizes, insertOp.getMixedStrides());
915 return success();
916 }
917};
918
919/// Fold tensor.extract(linalg.fill(<input>)) into <input>
920struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
921public:
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
923
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter) const override {
926 // See if tensor input of tensor.extract op is the result of a linalg.fill
927 // op.
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
929 if (!fillOp)
930 return failure();
931
932 // Get scalar input operand of linalg.fill op.
933 Value extractedScalar = fillOp.getInputs()[0];
934
935 // Replace tensor.extract op with scalar value used to fill the tensor.
936 rewriter.replaceOp(extractOp, extractedScalar);
937 return success();
938 }
939};
940
941/// Folds pack(fill) into a single fill op if
942/// 1. The pack op does not have padding value, or
943/// 2. The filled value and padding value are the same.
944static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
947 if (!fillOp)
948 return failure();
949
950 if (auto paddingValue = packOp.getPaddingValue())
951 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
952 return failure();
953
954 Value packOpDest = packOp.getDest();
955 if (!packOpDest.hasOneUse())
956 return failure();
957
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
959 packOp.getDest());
960}
961
962/// Wrapper pattern that applies foldFillPackIntoFillOp method.
963struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
964public:
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
967
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter) const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
971 if (failed(fillOp))
972 return failure();
973 rewriter.replaceOp(packOp, fillOp.value().result());
974 return success();
975 }
976};
977
978/// Fold fill with copy.
979struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
981
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter) const override {
984 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
985 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
986 fillOp.getInputs(),
987 copyOp.getOutputs());
988 return success();
989 }
990 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
991 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
992 fillOp.getOutputs());
993 return success();
994 }
995 return failure();
996 }
997};
998
999/// Fold fill with transpose.
1000struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1002
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter) const override {
1005 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1006 rewriter.replaceOpWithNewOp<FillOp>(
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1009 return success();
1010 }
1011 return failure();
1012 }
1013};
1014
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    // True iff `v` is produced by a fill with the same value as the first
    // fill. As a side effect, records that fill's init operand in `allOuts`.
    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    // Concat the fills' init tensors and fill the combined shape once.
    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
1065
1066} // namespace
1067
1068void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1069 MLIRContext *context) {
1070 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1071 FoldFillWithPack, FoldFillWithPad,
1072 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1073 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1074 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1075}
1076
1077//===----------------------------------------------------------------------===//
1078// GenericOp
1079//===----------------------------------------------------------------------===//
1080
/// Creates a structured-op body block in `region`: one scalar block argument
/// per input/output operand, populated by the `bodyBuild` callback.
static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  // Shaped (memref/tensor) operands get element-typed block arguments; any
  // other operand type is passed through unchanged.
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  // Create the entry block and let the callback emit the payload.
  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}
1101
1102void GenericOp::getAsmBlockArgumentNames(Region &region,
1103 OpAsmSetValueNameFn setNameFn) {
1104 for (Value v : getRegionInputArgs())
1105 setNameFn(v, "in");
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v, "out");
1108}
1109
1110void GenericOp::build(
1111 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1112 ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1113 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1114 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1115 ArrayRef<NamedAttribute> attributes) {
1116 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1117 iteratorTypes, doc, libraryCall);
1118 result.addAttributes(attributes);
1119 if (bodyBuild)
1120 buildGenericRegion(builder, result.location, *result.regions.front(),
1121 inputs, outputs, bodyBuild);
1122}
1123
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1126 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1129 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder, result, resultTensorTypes, inputs, outputs,
1132 builder.getAffineMapArrayAttr(indexingMaps),
1133 builder.getArrayAttr(llvm::map_to_vector(
1134 iteratorTypes,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1137 })),
1138 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1139 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1141}
1142
1143void GenericOp::build(
1144 OpBuilder &builder, OperationState &result, ValueRange inputs,
1145 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1146 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1147 StringRef libraryCall,
1148 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1149 ArrayRef<NamedAttribute> attributes) {
1150 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1151 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1152}
1153
1154void GenericOp::build(
1155 OpBuilder &builder, OperationState &result, ValueRange inputs,
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes,
1158 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1159 ArrayRef<NamedAttribute> attributes) {
1160 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1161 /*doc=*/"",
1162 /*libraryCall=*/"", bodyBuild, attributes);
1163}
1164
1165void GenericOp::build(
1166 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1167 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1169 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1172 iteratorTypes,
1173 /*doc=*/"",
1174 /*libraryCall=*/"", bodyBuild, attributes);
1175}
1176
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  // The linalg trait attributes are printed as a leading dictionary; any other
  // attribute is deferred to the trailing `attrs = {...}` clause below.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
          iteratorTypes, [&](utils::IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          });

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // operandSegmentSizes is implied by the ins/outs syntax, so elide it from
  // the trailing attribute dictionary as well.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Only emit `attrs = {...}` when at least one non-trait attribute remains.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1238
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  // Splat the dictionary entries directly onto the op's attribute list.
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  // Reject unknown iterator-type strings with a located error.
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  // The body region; its block arguments are declared inline by the region.
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1304
/// Populates `effects` with the memory effects of a structured op: memref
/// inputs are read, memref inits are written (and read when the payload uses
/// their incoming value). Tensor operands carry no memory effects.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  // Input operands backed by memrefs are only read from.
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  // Init operands backed by memrefs are written; they are additionally read
  // when the payload actually consumes the corresponding block argument.
  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}
1330
1331void GenericOp::getEffects(
1332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1333 &effects) {
1334 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1335}
1336
/// Shared speculatability logic for structured ops.
static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}
1346
1347Speculation::Speculatability GenericOp::getSpeculatability() {
1348 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1349}
1350
1351namespace {
1352
/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      // The yielded value must be the block argument itself, not a value
      // computed in the body (which would make this a fill-like op).
      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      // Copying a buffer onto itself is a no-op.
      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
1439
1440} // namespace
1441
1442void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443 MLIRContext *context) {
1444 results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1445}
1446
1447LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1448 return memref::foldMemRefCast(*this);
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// MapOp
1453//===----------------------------------------------------------------------===//
1454
/// Parses the shared destination-style syntax (`ins`/`outs`, result types,
/// then required and optional attributes). `parseAttrsFn`, when provided, is
/// invoked to parse op-specific required attributes.
static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  // Only tensor outputs produce SSA results; memref outputs do not.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
1480
1481void MapOp::getAsmBlockArgumentNames(Region &region,
1482 OpAsmSetValueNameFn setNameFn) {
1483 for (Value v : getRegionInputArgs())
1484 setNameFn(v, "in");
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v, "init");
1487}
1488
1489void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(), "mapped");
1492}
1493
1494void MapOp::build(
1495 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1496 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1497 ArrayRef<NamedAttribute> attributes) {
1498 build(builder, result, TypeRange{}, inputs, init);
1499 result.addAttributes(attributes);
1500
1501 // Add output types for `RankedTensorType` output arguments.
1502 Type initType = init.getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1505
1506 if (bodyBuild)
1507 buildGenericRegion(builder, result.location, *result.regions.front(),
1508 inputs, /*outputs=*/{init}, bodyBuild);
1509}
1510
/// Materializes the op body for the short (payload-op) syntax: a single block
/// with one scalar argument per operand, holding `payloadOpName` applied to
/// the block arguments and a linalg.yield of its results.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false, bool mapInit = true) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  // One element-typed block argument per op operand.
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    if (mapInit)
      payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end() - int(!mapInit)};
  }

  // The payload op produces a single value of the init's element type.
  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  YieldOp::create(b, result.location, payloadOp->getResults());
}
1546
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short form: `{ payload-op-name attr-dict }` before ins/outs.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Shared destination-style syntax: ins/outs, result types, attributes.
  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the mapper region from the payload op.
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs, ArrayRef(result.operands), false,
                           false);
    else
      result.addRegion();
  } else {
    // Long form: explicit block arguments and an explicit region body.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1583
/// Returns true if `body` can be printed in the compact `{ payload-op }`
/// short form instead of an explicit region.
static bool canUseShortForm(Block *body, bool initFirst = false,
                            bool mapInit = true) {
  // `initFirst == true` implies that we want to map init arg
  if (initFirst && !mapInit)
    return false;
  // Check if the body can be printed in short form. The following 4 conditions
  // must be satisfied:

  // 1) The body must contain exactly 2 operations: the payload op and a yield.
  if (body->getOperations().size() != 2)
    return false;
  Operation &payload = body->getOperations().front();

  // 2) The payload op must have the same number of operands as the number of
  //    block arguments (minus the unmapped init argument, if any).
  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
    return false;

  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
  //    must be the first operand of the payload op, otherwise, the operands
  //    must match the block arguments in order.
  if (initFirst) {
    // check init
    if (payload.getOperands().back() != body->getArgument(0))
      return false;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return false;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(),
                   body->getArguments().drop_back(int(!mapInit)))) {
      if (bbArg != operand)
        return false;
    }
  }

  // 4) The `yield` operand must be the result of the payload op.
  auto yieldOp = cast<YieldOp>(body->getTerminator());
  return yieldOp.getNumOperands() == 1 &&
         yieldOp.getOperand(0).getDefiningOp() &&
         yieldOp.getOperand(0).getDefiningOp() == &payload;
}
1631
1632static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1633 SmallVector<StringRef> elidedAttrs;
1634 std::string attrToElide;
1635 p << " { " << payloadOp->getName().getStringRef();
1636 for (const auto &attr : payloadOp->getAttrs()) {
1637 auto fastAttr =
1638 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1639 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1640 attrToElide = attr.getName().str();
1641 elidedAttrs.push_back(attrToElide);
1642 break;
1643 }
1644 }
1645 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1646 p << " }";
1647}
1648
1649void MapOp::print(OpAsmPrinter &p) {
1650 Block *mapper = getBody();
1651 bool useShortForm =
1652 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
1653 if (useShortForm) {
1654 printShortForm(p, &mapper->getOperations().front());
1655 }
1656
1657 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1658 p.printOptionalAttrDict((*this)->getAttrs());
1659
1660 if (!useShortForm) {
1661 // Print region if the payload op was not detected.
1662 p.increaseIndent();
1663 p.printNewline();
1664 p << "(";
1665 llvm::interleaveComma(mapper->getArguments(), p,
1666 [&](auto arg) { p.printRegionArgument(arg); });
1667 p << ") ";
1668
1669 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1670 p.decreaseIndent();
1671 }
1672}
1673
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1677
1678 // Checks if the number of `inputs` + `init` match the arity of the `mapper`
1679 // region.
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() << "expects number of operands to match the arity of "
1682 "mapper, but got: "
1683 << getInputs().size() + 1 << " and "
1684 << blockArgs.size();
1685
1686 // The parameters of mapper should all match the element type of inputs.
1687 for (const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() << "expected element type of input " << inputElemType
1693 << " to match bbArg type " << bbArgType;
1694 }
1695 }
1696
1697 // The shape of each input must match the shape of the output.
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType : TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() << "expected shape of input (" << inputElemShape
1703 << ") to match shape of output (" << outputShape
1704 << ")";
1705 }
1706 }
1707
1708 return success();
1709}
1710
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1714}
1715
1716ArrayAttr MapOp::getIndexingMaps() {
1717 Builder builder(getContext());
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1720 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1721 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1722}
1723
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1726 &effects) {
1727 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1728}
1729
1730Speculation::Speculatability MapOp::getSpeculatability() {
1731 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1732}
1733
1734//===----------------------------------------------------------------------===//
1735// ReduceOp
1736//===----------------------------------------------------------------------===//
1737
1738void ReduceOp::getAsmBlockArgumentNames(Region &region,
1739 OpAsmSetValueNameFn setNameFn) {
1740 for (Value v : getRegionInputArgs())
1741 setNameFn(v, "in");
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v, "init");
1744}
1745
1746void ReduceOp::getAsmResultNames(
1747 function_ref<void(Value, StringRef)> setNameFn) {
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(), "reduced");
1750}
1751
1752void ReduceOp::build(
1753 OpBuilder &builder, OperationState &result, ValueRange inputs,
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1755 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1756 ArrayRef<NamedAttribute> attributes) {
1757 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1758 result.addAttributes(attributes);
1759
1760 // Add output types for `RankedTensorType` output arguments.
1761 for (Value init : inits) {
1762 Type initType = init.getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1765 }
1766
1767 if (bodyBuild)
1768 buildGenericRegion(builder, result.location, *result.regions.front(),
1769 inputs, inits, bodyBuild);
1770}
1771
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1773 int64_t inputRank =
1774 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1780}
1781
/// Indexing maps: each input gets a map over the full input rank; each init
/// gets that map with the reduced dimensions dropped.
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // NOTE(review): the identity-map arguments of the two expressions below
  // appear truncated in this copy of the file (lines lost in extraction) --
  // verify against upstream before editing.
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
  AffineMap resultMap =
      .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1795
/// Collect memory effects via the shared Linalg DPS implementation.
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1801
/// Speculatability is decided by the generic Linalg implementation.
Speculation::Speculatability ReduceOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
1805
1806static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1807 NamedAttrList &attributes,
1808 StringRef attributeName) {
1809 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1810 return failure();
1811
1812 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1813 return success();
1814}
1815
/// Parse a `linalg.reduce`. Two forms are accepted:
///   - short form: `{ payload-op {attrs} } ins(...) outs(...)
///     dimensions = [...]`, where the region is synthesized from the single
///     payload op; and
///   - long form with explicit region arguments and body.
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short-form payload: `{ op-name {attr-dict} }`.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Common destination-style parts plus the mandatory `dimensions` attribute.
  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: rebuild the combiner region from the payload op, placing
    // init block arguments first.
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: parse explicit region arguments, then the region body.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1853
/// Print ` attributeName = [<values>] ` by streaming the ArrayRef; the exact
/// spacing here is part of the ops' assembly format.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1858
/// Print a `linalg.reduce`. Emits the compact short form when the combiner
/// body is a single recognizable payload op; otherwise prints the region
/// explicitly.
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
  if (useShortForm) {
    printShortForm(p, &mapper->getOperations().front());
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` is printed inline above, so elide it from the attr dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!useShortForm) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1882
/// Verify a `linalg.reduce`:
///  - all inputs share one shape, all inits share one shape;
///  - every reduction dimension is within the input rank;
///  - the init shape equals the input shape with the reduced dims removed;
///  - the combiner block has one argument per operand, matching element
///    types (inputs first, then inits).
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // All inputs must have the same shape as input 0.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  // All inits must have the same shape as init 0.
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Validate reduction dims and collect them into a set for fast lookup.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  // Rank check first, then exact shape comparison.
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1968
1969//===----------------------------------------------------------------------===//
1970// TransposeOp
1971//===----------------------------------------------------------------------===//
1972
1973static void buildIdentityRegion(OpBuilder &builder, Location loc,
1974 Region &region, ValueRange inputs,
1975 ValueRange outputs) {
1976 buildGenericRegion(builder, loc, region, inputs, outputs,
1977 [](OpBuilder &b, Location loc, ValueRange args) {
1978 if (!args.empty())
1979 linalg::YieldOp::create(b, loc, args[0]);
1980 });
1981}
1982
1983void TransposeOp::build(::mlir::OpBuilder &builder,
1984 ::mlir::OperationState &result, Value input, Value init,
1985 DenseI64ArrayAttr permutation,
1986 ArrayRef<NamedAttribute> attributes) {
1987 result.addOperands(input);
1988 result.addOperands(init);
1989 result.addAttribute(getPermutationAttrName(result.name), permutation);
1990 result.addAttributes(attributes);
1991
1992 // Add output types for `RankedTensorType` output arguments.
1993 Type initType = init.getType();
1994 if (llvm::isa<RankedTensorType>(initType))
1995 result.addTypes(initType);
1996
1997 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1998 init);
1999}
2000
/// Convenience builder taking the permutation as a plain integer array.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}
2008
/// Parse a `linalg.transpose`: destination-style parts plus the mandatory
/// `permutation` attribute. The region is never parsed from text; it is
/// rebuilt programmatically as the identity payload.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  // NOTE(review): the enclosing `if (failed(parseDstStyleOp(` call appears to
  // have been dropped from this copy (line lost in extraction) -- verify
  // against upstream before editing.
      parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
        return parseDenseI64ArrayAttr(parser, attributes, "permutation");
      })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
2022
2023void TransposeOp::getAsmResultNames(
2024 function_ref<void(Value, StringRef)> setNameFn) {
2025 if (!getResults().empty())
2026 setNameFn(getResults().front(), "transposed");
2027}
2028
/// Print a `linalg.transpose`: common ins/outs, the inline `permutation`
/// array, then the remaining attributes with `permutation` elided.
void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
2034
2035LogicalResult TransposeOp::verify() {
2036 ArrayRef<int64_t> permutationRef = getPermutation();
2037
2038 if (!isPermutationVector(permutationRef))
2039 return emitOpError("permutation is not valid");
2040
2041 auto inputType = getInput().getType();
2042 auto initType = getInit().getType();
2043
2044 int64_t rank = inputType.getRank();
2045
2046 if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
2047 "init")))
2048 return failure();
2049
2050 if (rank != static_cast<int64_t>(permutationRef.size()))
2051 return emitOpError() << "size of permutation " << permutationRef.size()
2052 << " does not match the argument rank " << rank;
2053
2054 auto inputDims = inputType.getShape();
2055 auto initDims = initType.getShape();
2056
2057 for (int64_t i = 0; i < rank; ++i) {
2058 int64_t inputDim = inputDims[permutationRef[i]];
2059 int64_t initDim = initDims[i];
2060
2061 if (inputDim != initDim) {
2062 return emitOpError() << "dim(result, " << i << ") = " << initDim
2063 << " doesn't match dim(input, permutation[" << i
2064 << "]) = " << inputDim;
2065 }
2066 }
2067
2068 return success();
2069}
2070
2071SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2072 int64_t rank = getInit().getType().getRank();
2073 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2074}
2075
/// Indexing maps: a permuted map for the input and the identity for the init.
ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  // NOTE(review): the first element of the array below (building the inverse
  // permutation map) appears truncated in this copy of the file (line lost in
  // extraction) -- verify against upstream before editing.
  return builder.getAffineMapArrayAttr(
          llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}
2084
/// Collect memory effects via the shared Linalg DPS implementation.
void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
2090
/// Speculatability is decided by the generic Linalg implementation.
Speculation::Speculatability TransposeOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
2094
2095LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2096 SmallVectorImpl<OpFoldResult> &result) {
2097 // Only the tensor type is supported.
2098 if (!isa<TensorType>(getInput().getType()))
2099 return failure();
2100
2101 // Single dimension transpose.
2102 if (getPermutation().empty()) {
2103 result.push_back(getInput());
2104 return success();
2105 }
2106 // Identity permutation.
2107 if (isIdentityPermutation(getPermutation())) {
2108 result.push_back(getInput());
2109 return success();
2110 }
2111
2112 return failure();
2113}
2114
2115/// Fold transpose with transpose.
2116struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2117 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2118
2119 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2120 PatternRewriter &rewriter) const override {
2121 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2122 if (!defTransposeOp)
2123 return failure();
2124 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2125 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2126 SmallVector<int64_t> foldedPerms;
2127 foldedPerms.reserve(perms.size());
2128 for (int64_t perm : perms)
2129 foldedPerms.push_back(defPerms[perm]);
2130
2131 rewriter.replaceOpWithNewOp<TransposeOp>(
2132 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2133 foldedPerms);
2134 return success();
2135 }
2136};
2137
/// This pattern canonicalize transpose by swapping the order of
/// broadcast and transpose:
/// transpose(broadcast(input)) -> broadcast(transpose(input))
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    // Only rewrite when the broadcast feeds this transpose exclusively.
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    // NOTE(review): the declaration of `invertPerm` (presumably the inverse
    // of `perms`) appears to have been dropped from this copy (line lost in
    // extraction) -- verify against upstream before editing.
    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    // NOTE(review): the declaration of `dims` (a SmallVector<OpFoldResult>)
    // also appears to have been dropped here -- verify against upstream.
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    // Collect each source dim as either a dynamic tensor.dim or a static attr.
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
                           ->getResult(0));
      } else {
        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
                                        broadcastInputTy.getDimSize(i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(dims, resultPerms);
    Value transposeInit = tensor::EmptyOp::create(
        rewriter, transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
    Value transposeResult =
        TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
                            transposeInit, resultPerms)
            ->getResult(0);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
    return success();
  }
};
2195
/// Register transpose canonicalizations: fold chained transposes and swap a
/// transpose with its producing broadcast.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
}
2200
2201//===----------------------------------------------------------------------===//
2202// BroadcastOp
2203//===----------------------------------------------------------------------===//
2204
2205void BroadcastOp::build(::mlir::OpBuilder &builder,
2206 ::mlir::OperationState &result, Value input, Value init,
2207 DenseI64ArrayAttr dimensions,
2208 ArrayRef<NamedAttribute> attributes) {
2209 result.addOperands(input);
2210 result.addOperands(init);
2211 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2212 result.addAttributes(attributes);
2213
2214 // Add output types for `RankedTensorType` output arguments.
2215 Type initType = init.getType();
2216 if (llvm::isa<RankedTensorType>(initType))
2217 result.addTypes(initType);
2218
2219 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2220 init);
2221}
2222
/// Convenience builder taking the broadcast dimensions as a plain array.
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}
2230
/// Parse a `linalg.broadcast`: destination-style parts plus the mandatory
/// `dimensions` attribute; the identity region is rebuilt programmatically.
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  // NOTE(review): the enclosing `if (failed(parseDstStyleOp(` call appears to
  // have been dropped from this copy (line lost in extraction) -- verify
  // against upstream before editing.
      parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
        return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
      })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
2244
2245void BroadcastOp::getAsmResultNames(
2246 function_ref<void(Value, StringRef)> setNameFn) {
2247 if (!getResults().empty())
2248 setNameFn(getResults().front(), "broadcasted");
2249}
2250
/// Print a `linalg.broadcast`: common ins/outs, the inline `dimensions`
/// array, then the remaining attributes with `dimensions` elided.
void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
2256
2257LogicalResult BroadcastOp::verify() {
2258 ArrayRef<int64_t> dimensionsRef = getDimensions();
2259
2260 auto inputType = getInput().getType();
2261 auto initType = getInit().getType();
2262
2263 int64_t inputRank = inputType.getRank();
2264 int64_t initRank = initType.getRank();
2265
2266 auto inputShape = inputType.getShape();
2267 auto initShape = initType.getShape();
2268
2269 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2270 return emitOpError() << "input rank plus added dimensions does not "
2271 "match init rank. input rank: "
2272 << inputRank
2273 << ", dimensions size: " << dimensionsRef.size()
2274 << ", init rank: " << initRank;
2275
2276 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2277 if (dim < 0 || dim >= initRank)
2278 return emitOpError() << "dimension " << idx
2279 << " is out of range. expected range: [0, "
2280 << initRank - 1 << "], got: " << dim;
2281 }
2282
2283 // Mapping from input dims to init dims.
2284 SmallVector<int64_t> dimMap;
2285 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2286 if (!llvm::is_contained(dimensionsRef, dim))
2287 dimMap.push_back(dim);
2288 }
2289
2290 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2291 // This dimensions is mapped from the input. Init and input dims should
2292 // match.
2293 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2294 return emitOpError() << "input dim " << inputDimIdx
2295 << " should match init dim " << initDimIdx
2296 << ". input: " << inputShape[inputDimIdx]
2297 << ", init: " << initShape[initDimIdx];
2298 }
2299
2300 return success();
2301}
2302
2303SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2304 int64_t rank = getInit().getType().getRank();
2305 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2306}
2307
2308ArrayAttr BroadcastOp::getIndexingMaps() {
2309 Builder builder(getContext());
2310 int64_t rank = getInit().getType().getRank();
2311 return builder.getAffineMapArrayAttr(
2312 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2313 builder.getMultiDimIdentityMap(rank)});
2314}
2315
/// Collect memory effects via the shared Linalg DPS implementation.
void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
2321
/// Speculatability is decided by the generic Linalg implementation.
Speculation::Speculatability BroadcastOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
2325
2326/// Fold back-to-back broadcasts together.
2327struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2328 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2329
2330 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2331 PatternRewriter &rewriter) const override {
2332 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2333 if (!defBroadcastOp)
2334 return failure();
2335 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2336 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2337 SmallVector<int64_t> foldedDims(dimensions);
2338 Value init = broadcastOp.getInit();
2339 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2340 // Mapping from input dims to init dims.
2341 SmallVector<int64_t> dimMap;
2342 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2343 if (!llvm::is_contained(dimensions, dim))
2344 dimMap.push_back(dim);
2345 }
2346 for (auto dim : defDimensions)
2347 foldedDims.push_back(dimMap[dim]);
2348
2349 llvm::sort(foldedDims);
2350 rewriter.replaceOpWithNewOp<BroadcastOp>(
2351 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2352 return success();
2353 }
2354};
2355
/// Register broadcast canonicalizations: erase identity broadcasts and fold
/// back-to-back broadcasts.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
}
2360
2361//===----------------------------------------------------------------------===//
2362// YieldOp
2363//===----------------------------------------------------------------------===//
2364
2365void linalg::YieldOp::print(OpAsmPrinter &p) {
2366 if (getNumOperands() > 0)
2367 p << ' ' << getOperands();
2368 p.printOptionalAttrDict((*this)->getAttrs());
2369 if (getNumOperands() > 0)
2370 p << " : " << getOperandTypes();
2371}
2372
/// Parse `linalg.yield`: an optional operand list, an optional attr dict,
/// and -- only when operands are present -- a colon-separated type list.
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  // The short-circuit chain stops at the first failing step; the type list
  // is only required when at least one operand was parsed.
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}
2382
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // One yielded value per init/out operand of the enclosing op.
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    // For shaped inits compare against the element type; scalar inits are
    // compared as-is.
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
2407
2408LogicalResult linalg::YieldOp::verify() {
2409 auto *parentOp = (*this)->getParentOp();
2410 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2411 return emitOpError("expected single non-empty parent region");
2412
2413 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2414 return verifyYield(*this, linalgOp);
2415
2416 return emitOpError("expected parent op with LinalgOp interface");
2417}
2418
2419//===----------------------------------------------------------------------===//
2420// IndexOp
2421//===----------------------------------------------------------------------===//
2422
2423LogicalResult IndexOp::verify() {
2424 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2425 if (!linalgOp)
2426 return emitOpError("expected parent op with LinalgOp interface");
2427 if (linalgOp.getNumLoops() <= getDim())
2428 return emitOpError("expected dim (")
2429 << getDim() << ") to be lower than the number of loops ("
2430 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2431 return success();
2432}
2433
2434OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2435 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2436 // Bail out if `linalg.index` does not have a proper parent yet at this
2437 // point, e.g., when calling `createOrFold` during IR construction in
2438 // `genericOp::build`.
2439 if (!linalgOp)
2440 return OpFoldResult{};
2441
2442 // Index of unit dims is always 0.
2443 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2444 uint64_t dim = getDim();
2445 assert(dim < loopBounds.size() && "Dim is out of bounds");
2446 if (loopBounds[dim] == 1)
2447 return IntegerAttr::get(IndexType::get(getContext()), 0);
2448
2449 return OpFoldResult{};
2450}
2451
2452/////// Operations corresponding to library calls defined with Tablegen ////////
2453
2454#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2455
2456#define GET_OP_CLASSES
2457#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2458
2459#define GET_OP_CLASSES
2460#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2461#define GET_OP_CLASSES
2462#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2463
2464AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2465 unsigned rank,
2466 MLIRContext *context) {
2467 if (maybeMap)
2468 return *maybeMap;
2469 if (rank == 0)
2470 return AffineMap::get(context);
2471 return AffineMap::getMultiDimIdentityMap(rank, context);
2472}
2473
// Create `num` consecutive dim expressions starting at `startIdx`, advancing
// `startIdx` past them for the caller.
// NOTE(review): the return-type line and the declaration of `res` appear to
// have been dropped from this copy (lines lost in extraction) -- verify
// against upstream before editing.
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}
2483
  // NOTE(review): the signature of this helper (concatenating two AffineExpr
  // collections `a` and `b`) appears to have been dropped from this copy --
  // verify against upstream before editing.
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  // Lazily chain the two ranges, then materialize them into one vector.
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}
2491
/// Append a mangled spelling of type `t` to `ss` for use in library call
/// names: memrefs become `view<dims>x<elt>[as<space>]` (dynamic dims as `s`),
/// vectors become `vector<dims><elt>`. Returns failure for unsupported types.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    // Dynamic sizes are mangled as "s".
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    // Only integer memory-space attributes are supported.
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  // NOTE(review): the condition guarding the scalar case below appears to
  // have been dropped from this copy (line lost in extraction) -- verify
  // against upstream before editing.
    ss << t;
    return success();
  }
  return failure();
}
2524
  // NOTE(review): the signature of this function appears to have been
  // dropped from this copy (line lost in extraction) -- verify upstream.
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  // Pick up a unary/binary function attribute, if present, to embed in the
  // mangled name.
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  name.reserve(128);
  // Dots in the op name are not valid in symbol names.
  llvm::replace(name, '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  // Append one mangled entry per operand type; bail out to an empty string
  // if any type cannot be mangled.
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  // Drop the trailing separator.
  name.pop_back();
  return name;
}
2548
2549//===----------------------------------------------------------------------===//
2550// Canonicalizers and Folders.
2551//===----------------------------------------------------------------------===//
2552
2553namespace {
/// Erase Linalg ops that provably perform zero iterations because one of
/// their memref operands has a static zero-sized dimension.
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
// NOTE(review): the `using ...::OpInterfaceRewritePattern;` line appears to
// have been dropped from this copy (line lost in extraction) -- verify
// against upstream before editing.

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};
2574
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // Cast can be in conditionally reachable region, if which case folding will
    // generate invalid code. Only conservatively fold ops in same block for
    // now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    // New ops are created just before the producing linalg op.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    // Full operand list for the clone: inputs followed by (updated) inits.
    newOperands.append(outputOperands.begin(), outputOperands.end());

    // Clone the op with the more static result type for the folded result.
    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = tensor::CastOp::create(
        rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));

    // Other users of the old op keep the original (cast-back) type; the cast
    // op itself is replaced by the more static new result.
    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
2633
2634/// For each of the operand in `operands` this function maps the static sizes of
2635/// dimensions to their affine dim expressions.
2636static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2637 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2638 for (OpOperand &opOperand : operands) {
2639 if (linalgOp.isScalar(&opOperand))
2640 continue;
2641 Value src = opOperand.get();
2642 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2643 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2644
2645 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2646 // `tensor.cast` operation and source of the cast operation has a static
2647 // shape, then assign it to the `sourceShape`.
2648 auto *parentOp = src.getDefiningOp();
2649 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2650 if (parentOp) {
2651 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2652 Value castSource = castOp.getSource();
2653 auto castSourceType =
2654 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2655 if (castSourceType && castSourceType.hasStaticShape())
2656 sourceShape = castSourceType.getShape();
2657 }
2658 }
2659
2660 // If the source shape's dimension has a static shape, map the affine dim
2661 // expression to the known static size.
2662 for (unsigned i = 0; i < sourceShape.size(); i++) {
2663 if (sourceType.isDynamicDim(i))
2664 continue;
2665 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2666 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2667 }
2668 }
2669}
2670
/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are appended to `newOperands`
/// and, for init operands, their result types are appended to `resultTypes`.
/// If `opOperand` requires no change, `changeNeeded` is left untouched and
/// the same operand is added to the `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Tentatively append the unchanged operand; it is overwritten in place
  // below if a more static type can be inferred.
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  // Fully static init operand: nothing to infer; keep its type as a result
  // type.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    // Keep the existing size unless the dim is dynamic AND a static size was
    // inferred for its affine dim expression.
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  // Init operands contribute the (possibly updated) type to the results.
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2723
2724/// Static shapes for the operands can be inferred if any one of the operands
2725/// have a static shape. This can be done by referring to the affine dim
2726/// expressions for the operand.
2727struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2728 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2729
2730 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2731 PatternRewriter &rewriter) const override {
2732 if (!linalgOp.hasPureTensorSemantics())
2733 return failure();
2734
2735 // Maps must be projected permutations.
2736 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2737 return !map.isProjectedPermutation();
2738 }))
2739 return failure();
2740
2741 // Maps affine dim expressions to the static size of that dimension.
2742 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2743 Location loc = linalgOp.getLoc();
2744
2745 // For each of the affine dim expression, check if the size is known. If
2746 // known add that in the map.
2747 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2748
2749 SmallVector<Value> newOperands;
2750 SmallVector<Type> resultTypes;
2751
2752 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2753 // change in their types.
2754 bool changeNeeded = false;
2755 newOperands.reserve(linalgOp->getNumOperands());
2756 resultTypes.reserve(linalgOp.getNumDpsInits());
2757
2758 // Iterate over all the operands and update the static sizes.
2759 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2760 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2761 affineExprToSize, linalgOp, newOperands,
2762 resultTypes, changeNeeded);
2763 }
2764
2765 // If the generic op has all the required static information, no
2766 // canonicalization needed.
2767 if (!changeNeeded)
2768 return failure();
2769
2770 // Clone op.
2771 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2772 SmallVector<Value> replacements;
2773 replacements.reserve(newOp->getNumResults());
2774 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2775 Value newResult = std::get<1>(it);
2776 Value oldResult = std::get<0>(it);
2777 Type newType = newResult.getType();
2778 Type oldType = oldResult.getType();
2779 replacements.push_back(
2780 (newType != oldType)
2781 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2782 : newResult);
2783 }
2784 rewriter.replaceOp(linalgOp, replacements);
2785 return success();
2786 }
2787};
2788
2789} // namespace
2790
2791// All named ops canonicalizers and folders are auto-generated in the
2792// .cpp.inc.
2793
2794//===----------------------------------------------------------------------===//
2795// SoftmaxOp
2796//===----------------------------------------------------------------------===//
2797
2798LogicalResult SoftmaxOp::verify() {
2799 ShapedType inputType = getInputOperandType();
2800 ShapedType outputType = getOutputOperandType();
2801
2802 ArrayRef<int64_t> inputShape = inputType.getShape();
2803 ArrayRef<int64_t> outputShape = outputType.getShape();
2804 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2805 return emitOpError("incompatible output shape");
2806
2807 int64_t inputRank = getInputOperandRank();
2808 int64_t dimension = getDimension();
2809 if ((dimension < 0) || (dimension >= inputRank))
2810 return emitOpError("incorrect dimension specified");
2811
2812 return success();
2813}
2814
2815SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2816 int64_t operandRank = getInputOperandRank();
2817 SmallVector<Range> loopBounds(operandRank);
2818 Location loc = getLoc();
2819 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2820 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2821 Value source = getInput();
2822 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2823 loopBounds[dim].offset = zero;
2824 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2825 loopBounds[dim].stride = one;
2826 }
2827 return loopBounds;
2828}
2829
2830SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2831 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2832 utils::IteratorType::parallel);
2833 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2834 return iteratorTypes;
2835}
2836
/// Tile the softmax by taking matching slices of the input and the output
/// (same offsets, sizes, and unit strides) and cloning the op onto them.
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  // Unit strides for all dimensions.
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  Operation *inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (!inputSlice) {
    return emitOpError("failed to compute input slice");
  }
  tiledOperands.emplace_back(inputSlice->getResult(0));
  Operation *outputSlice =
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  if (!outputSlice) {
    return emitOpError("failed to compute output slice");
  }
  tiledOperands.emplace_back(outputSlice->getResult(0));

  // Only tensor-semantic ops produce results; the tiled result type is the
  // output slice type.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
2869
2870LogicalResult SoftmaxOp::getResultTilePosition(
2871 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2872 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2873 SmallVector<OpFoldResult> &resultSizes) {
2874 if (resultNumber == 0) {
2875 resultOffsets.assign(offsets.begin(), offsets.end());
2876 resultSizes.assign(sizes.begin(), sizes.end());
2877 return success();
2878 }
2879 return failure();
2880}
2881
// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Fold `memref.cast` producers into this op's buffer operands.
  return memref::foldMemRefCast(*this);
}
2886
2887LogicalResult
2888SoftmaxOp::reifyResultShapes(OpBuilder &b,
2889 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2890 SmallVector<OpFoldResult> shapes;
2891 Location loc = getOperation()->getLoc();
2892 IRRewriter rewriter(b);
2893 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2894 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2895 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2896 if (!outputShapedType.isDynamicDim(dim)) {
2897 // Static dim: Return IntegerAttr.
2898 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2899 } else {
2900 // Dynamic dim: Return Value.
2901 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2902 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2903 }
2904 }
2905 reifiedReturnShapes.emplace_back(std::move(shapes));
2906 return success();
2907}
2908
2909void SoftmaxOp::getEffects(
2910 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2911 &effects) {
2912 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2913 if (!llvm::isa<MemRefType>(operand.getType()))
2914 continue;
2915 effects.emplace_back(MemoryEffects::Read::get(),
2916 &getOperation()->getOpOperand(index), /*stage=*/0,
2917 /*effectOnFullRegion=*/true,
2919 }
2920
2921 for (OpOperand &operand : getDpsInitsMutable()) {
2922 if (!llvm::isa<MemRefType>(operand.get().getType()))
2923 continue;
2924 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2925 /*effectOnFullRegion=*/true,
2927 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2928 /*effectOnFullRegion=*/true,
2930 }
2931}
2932
2933// Helper functions for softmax decomposition.
2934// @{
2935
2936// Helper function to produce the iterator types (reduction or parallel) and
2937// affine maps for the iterators used in the decomposition of softmax.
2938// This method creates:
2939// If allParallel == true:
2940// - iterator type: {parallel, ..., parallel}
2941// - affine maps:
2942// -- identity with inputRank dimensions.
2943// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2944// where N == inputRank.
2945//
2946// If allParallel == false:
2947// - iterator type at dim(i) == parallel for i != \p dim and
2948// dim(dim) == reduction.
2949// - affine map:
2950// -- identity with inputRank dimensions.
2951// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2952// where N == inputRank.
2953static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2955 int64_t dim, bool allParallel = false) {
2956 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2957 utils::IteratorType::parallel);
2958 if (!allParallel)
2959 iteratorTypes[dim] = utils::IteratorType::reduction;
2960 MLIRContext *ctxt = builder.getContext();
2961 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2962 SmallVector<AffineExpr, 2> affineExprs;
2963 for (int i = 0; i < inputRank; i++) {
2964 if (i != dim)
2965 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2966 }
2967 auto reductionMap =
2968 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2969 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2970 return std::make_tuple(iteratorTypes, indexingMaps);
2971}
2972
2973// Helper function to produce a linalg.generic that computes a reduction on
2974// dimension \p dim with the operation type \p T.
2975template <typename T>
2976static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2977 int64_t dim) {
2978 auto inputType = cast<ShapedType>(input.getType());
2979 ArrayRef<int64_t> inputShape = inputType.getShape();
2980 int64_t inputRank = inputShape.size();
2981 auto [iteratorTypes, indexingMaps] =
2982 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2983 assert(indexingMaps.size() == 2 &&
2984 "We should have two maps: 1 for the input, 1 for the output");
2985 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2986
2987 auto genericOp = linalg::GenericOp::create(
2988 builder, loc, output.getType(), input, output, indexingMaps,
2989 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2990 Value result = T::create(b, loc, args[0], args[1]);
2991 linalg::YieldOp::create(b, loc, result);
2992 });
2993 return genericOp.getResult(0);
2994}
2995
2996/// Produce a linalg generic that computes the second step of the softmax
2997/// decomposition: res = exp(input - max), where \p max is the max of \p input
2998/// on dimension \p dim.
2999static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3000 Value max, Value output, int64_t dim) {
3001 auto inputType = cast<ShapedType>(input.getType());
3002 ArrayRef<int64_t> inputShape = inputType.getShape();
3003 int64_t inputRank = inputShape.size();
3004 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3005 builder, inputRank, dim, /*allParallel=*/true);
3006 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3007 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3008 // Add the affine map for the output argument.
3009 indexingMaps.push_back(indexingMaps[0]);
3010 auto genericOp = linalg::GenericOp::create(
3011 builder, loc, input.getType(), ValueRange{input, max}, output,
3012 indexingMaps, iteratorTypes,
3013 [&](OpBuilder &b, Location loc, ValueRange args) {
3014 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3015 Value result = math::ExpOp::create(b, loc, diff);
3016 linalg::YieldOp::create(b, loc, result);
3017 });
3018 return genericOp.getResult(0);
3019}
3020
3021/// Produce a linalg generic that computes the final step of the softmax
3022/// decomposition.
3023/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3024/// yield n / d
3025/// }
3026static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3027 Value denominator, Value output, int64_t dim) {
3028 auto inputType = cast<ShapedType>(numerator.getType());
3029 ArrayRef<int64_t> inputShape = inputType.getShape();
3030 int64_t inputRank = inputShape.size();
3031 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3032 builder, inputRank, dim, /*allParallel=*/true);
3033 assert(indexingMaps.size() == 2 &&
3034 "We should have one map for each input (2)");
3035 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3036 // Add the affine map for the output tensor.
3037 indexingMaps.push_back(indexingMaps[0]);
3038 auto genericOp = linalg::GenericOp::create(
3039 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3040 output, indexingMaps, iteratorTypes,
3041 [&](OpBuilder &b, Location loc, ValueRange args) {
3042 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3043 linalg::YieldOp::create(b, loc, result);
3044 });
3045 return genericOp.getResult(0);
3046}
3047// @} End helper functions for softmax decomposition.
3048
/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in a N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    a N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    a N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  // Shape of the N-1 dimensional reduction results: drop the reduced dim.
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
  // Seed the max reduction with the (finite) identity of maxnumf.
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
3105
3106//===----------------------------------------------------------------------===//
3107// WinogradFilterTransformOp
3108//===----------------------------------------------------------------------===//
3109
3110LogicalResult WinogradFilterTransformOp::verify() {
3111 auto filterType = cast<ShapedType>(getFilter().getType());
3112 ArrayRef<int64_t> filterShape = filterType.getShape();
3113 int64_t filterH = filterShape[getFilterHDim()];
3114 int64_t filterW = filterShape[getFilterWDim()];
3115 WinogradConv2DFmr fmr = getFmr();
3116 int64_t m, r;
3117 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3118
3119 if (filterH != r && filterH != 1)
3120 return emitOpError("expect filter height either equals to r or 1");
3121 if (filterW != r && filterW != 1)
3122 return emitOpError("expect filter width either equals to r or 1");
3123 if (filterH == 1 && filterW == 1)
3124 return emitOpError("expect either filter height or width equals to r");
3125
3126 SmallVector<int64_t> expectedOutputShape;
3127 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3128 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3129 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3130 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3131
3132 auto outputType = cast<ShapedType>(getOutput().getType());
3133 ArrayRef<int64_t> outputShape = outputType.getShape();
3134 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3135 return emitOpError("the output shape is not expected");
3136 }
3137 return success();
3138}
3139
3140SmallVector<Range>
3141WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3142 Location loc = getLoc();
3143 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3144 IntegerAttr oneAttr = builder.getIndexAttr(1);
3145 Value filter = getFilter();
3146 int64_t filterRank = getFilterOperandRank();
3147 SmallVector<Range> loopBounds(filterRank);
3148 for (unsigned dim = 0; dim < filterRank; ++dim) {
3149 loopBounds[dim].offset = zeroAttr;
3150 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3151 loopBounds[dim].stride = oneAttr;
3152 }
3153 return loopBounds;
3154}
3155
3156SmallVector<utils::IteratorType>
3157WinogradFilterTransformOp::getLoopIteratorTypes() {
3158 int64_t filterRank = getFilterOperandRank();
3159 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3160 utils::IteratorType::parallel);
3161 return iteratorTypes;
3162}
3163
3164LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3165 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3166 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3167 SmallVector<OpFoldResult> &resultSizes) {
3168 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3169 ShapedType filterType = getFilterOperandType();
3170 ArrayRef<int64_t> filterShape = filterType.getShape();
3171 int64_t filterH = filterShape[getFilterHDim()];
3172 int64_t filterW = filterShape[getFilterWDim()];
3173 WinogradConv2DFmr fmr = getFmr();
3174 int64_t m, r;
3175 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3176 int64_t alpha = m + r - 1;
3177 int64_t alphaH = filterH != 1 ? alpha : 1;
3178 int64_t alphaW = filterW != 1 ? alpha : 1;
3179 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3180 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3181
3182 resultOffsets.append(
3183 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3184 resultSizes.append(
3185 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3186
3187 return success();
3188}
3189
/// Implement tiling for winograd_filter_transform
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // Slice the filter: tile F and C, keep the full KH and KW extents.
  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = tensor::ExtractSliceOp::create(
      builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  // Compute the matching output tile from the same offsets/sizes.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the transform onto the slices; the result type is the output slice
  // type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
3242
3243//===----------------------------------------------------------------------===//
3244// WinogradInputTransformOp
3245//===----------------------------------------------------------------------===//
3246
LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  int64_t tileSize = m + r - 1;

  // An alpha dim of 1 in the output means the transform is not applied along
  // that spatial axis (1-D Winograd).
  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    // Dynamic input height: the tile count is unknown, alpha is fixed.
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    // Dynamic input width: the tile count is unknown, alpha is fixed.
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  // Batch and channel dims carry over from the input.
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3287
3288SmallVector<Range>
3289WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3290 Location loc = getLoc();
3291 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3292 IntegerAttr oneAttr = builder.getIndexAttr(1);
3293 Value output = getOutput();
3294 int64_t outputRank = getOutputOperandRank();
3295 SmallVector<Range> loopBounds(outputRank);
3296 for (unsigned dim = 0; dim < outputRank; ++dim) {
3297 loopBounds[dim].offset = zeroAttr;
3298 // alphaH, alphaW, tileH, tileW, N, C
3299 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3300 loopBounds[dim].stride = oneAttr;
3301 }
3302 return loopBounds;
3303}
3304
3305SmallVector<utils::IteratorType>
3306WinogradInputTransformOp::getLoopIteratorTypes() {
3307 int64_t outputRank = getOutputOperandRank();
3308 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3309 utils::IteratorType::parallel);
3310 return iteratorTypes;
3311}
3312
3313LogicalResult WinogradInputTransformOp::getResultTilePosition(
3314 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3315 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3316 SmallVector<OpFoldResult> &resultSizes) {
3317 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3318 ShapedType outputType = getOutputOperandType();
3319 ArrayRef<int64_t> outputShape = outputType.getShape();
3320 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3321 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3322
3323 WinogradConv2DFmr fmr = getFmr();
3324 int64_t m, r;
3325 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3326 int64_t alpha = m + r - 1;
3327 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3328 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3329
3330 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3331 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3332
3333 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3334 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3335 offsets[getOutputCDim()]});
3336 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3337 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3338 sizes[getOutputCDim()]});
3339
3340 return success();
3341}
3342
3343/// Implement tiling for winograd_input_transform
3344/// The input of winograd_input_transform is (N, H, W, C).
3345/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3346/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3347/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3348/// the values for the sizes of tileH, tileW, N, C for one tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  // m: output tile size, r: filter size of the Winograd F(m, r) algorithm.
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  // An alpha dimension of 1 means the transform is not applied along that
  // spatial dimension; its offset/size then pass through unchanged.
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  // Input H/W offsets are the tile offsets scaled by the output tile size m
  // (identity map for untransformed dimensions).
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  // A run of `size` tiles reads size * m + (r - 1) input elements along a
  // transformed dimension.
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // Extract the (N, H, W, C) slice of the input corresponding to this tile.
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  // Extract the matching (alphaH, alphaW, tileH, tileW, N, C) output slice.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the result type is the output slice type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
3422
3423//===----------------------------------------------------------------------===//
3424// WinogradOutputTransformOp
3425//===----------------------------------------------------------------------===//
3426
/// Verifies that the output shape of winograd_output_transform is consistent
/// with the value operand shape and the Winograd F(m, r) parameters.
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // An alpha dimension of 1 means the transform is skipped along that axis.
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  // H: a transformed dimension must have alpha = m + r - 1 and yields m
  // output rows per tile; dynamic sizes are left unconstrained.
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  // W: same check as H, applied to the width dimension.
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  // Batch (N) and feature (F) dimensions pass through unchanged.
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3467
3468SmallVector<Range>
3469WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3470 Location loc = getLoc();
3471 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3472 IntegerAttr oneAttr = builder.getIndexAttr(1);
3473 Value value = getValue();
3474 int64_t valueRank = getValueOperandRank();
3475 SmallVector<Range> loopBounds(valueRank);
3476 for (unsigned dim = 0; dim < valueRank; ++dim) {
3477 loopBounds[dim].offset = zeroAttr;
3478 // alphaH, alphaW, tileH, tileW, N, F
3479 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3480 loopBounds[dim].stride = oneAttr;
3481 }
3482 return loopBounds;
3483}
3484
3485SmallVector<utils::IteratorType>
3486WinogradOutputTransformOp::getLoopIteratorTypes() {
3487 int64_t valueRank = getValueOperandRank();
3488 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3489 utils::IteratorType::parallel);
3490 return iteratorTypes;
3491}
3492
3493LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3494 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3495 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3496 SmallVector<OpFoldResult> &resultSizes) {
3497 WinogradConv2DFmr fmr = getFmr();
3498 int64_t m, r;
3499 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3500
3501 Location loc = getLoc();
3502 MLIRContext *context = builder.getContext();
3503 auto identityAffineMap =
3504 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3505 auto affineMap =
3506 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3507
3508 ShapedType valueType = getValueOperandType();
3509 ArrayRef<int64_t> valueShape = valueType.getShape();
3510 int64_t valueH = valueShape[0];
3511 int64_t valueW = valueShape[1];
3512 Value mappedOffsetH = affine::makeComposedAffineApply(
3513 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3514 offsets[getValueTileHDim()]);
3515 Value mappedOffsetW = affine::makeComposedAffineApply(
3516 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3517 offsets[getValueTileWDim()]);
3518 Value mappedSizeH = affine::makeComposedAffineApply(
3519 builder, loc, affineMap, sizes[getValueTileHDim()]);
3520 Value mappedSizeW = affine::makeComposedAffineApply(
3521 builder, loc, affineMap, sizes[getValueTileWDim()]);
3522
3523 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3524 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3525 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3526 OpFoldResult sizeH =
3527 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3528 OpFoldResult sizeW =
3529 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3530
3531 resultOffsets.append(
3532 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3533 resultSizes.append(
3534 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3535 return success();
3536}
3537
3538/// Implement tiling for winograd_output_transform
3539/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3540/// F). The output of winograd_output_transform is (N, H, W, F) Users can
3541/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3542/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3543/// for the sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  // The alpha dimensions are not tiled: the slice always spans them fully.
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // Slice of the (alphaH, alphaW, tileH, tileW, N, F) value operand.
  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = tensor::ExtractSliceOp::create(
      builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  // Matching (N, H, W, F) output slice for this tile.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = tensor::ExtractSliceOp::create(
      builder, loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the result type is the output slice type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
3593
3594//===----------------------------------------------------------------------===//
3595// LinalgDialect
3596// TODO: Merge with the LinalgDialect block at the bottom
3597//===----------------------------------------------------------------------===//
3598
3599// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3600static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3601 auto explicitRange = subMap.getResults();
3602 auto defaultRange = fullMap.getResults();
3603 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3604 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3605 llvm::set_union(explicitSet, defaultSet);
3606 return explicitSet == defaultSet;
3607}
3608
3609/// Check if the user defined map is valid broadcast map. Here broadcast
3610/// indexing maps are defined in context of corresponding default indexing maps
3611/// for the given Op. This way the check becomes very simple i.e just check the
3612/// number of result dims.
3613/// Returns true if the explictMap is broadcasted with respect to the
3614/// defaultMap.
3615static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3616 return explictMap.getNumResults() < defaultMap.getNumResults();
3617}
3618
3619/// Verifies the broadcast and transpose semantic sepecified by the explicit
3620/// indexing map for the MatmulOp \p op for each operand specified by \p
3621/// opIndex.
3622static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3623 unsigned opIndex) {
3624 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3625 SmallVector<AffineMap, 3> defaultIndexingMaps =
3626 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3627
3628 auto opIndexingMap = opIndexingMaps[opIndex];
3629 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3630 // Check general validity of indexing map results.
3631 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3632 return matmulOp->emitOpError()
3633 << "Unexpected dim expression in map result.";
3634
3635 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3636 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3637 return matmulOp->emitOpError()
3638 << "Invalid broadcast requested, should be (d2).";
3639 }
3640 return success();
3641 }
3642 return success();
3643}
3644
3645// Check general validity of input indexing map of
3646// BatchMatmulOp/BatchReduceMatmulOp.
3647template <typename OpTy>
3648static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3649 AffineMap opIndexingMap,
3650 AffineMap defaultIndexingMap, bool isLHS) {
3651 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3652 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3653 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3654 // Check the result dims are valid.
3655 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3656 return batchVariantMatmulOp->emitOpError()
3657 << "Unexpected result dim expression (outside the set of default "
3658 "result dims).";
3659
3660 // Check for valid number of result dims of input maps.
3661 if (opIndexingMap.getNumResults() > 3)
3662 return batchVariantMatmulOp->emitOpError()
3663 << "no. of result dim expressions exceeds 3.";
3664
3665 auto hasValidBatchDim = [](AffineMap map) {
3666 AffineExpr batchDim = map.getResult(0);
3667 return batchDim.isFunctionOfDim(0);
3668 };
3669
3670 // Check if the requested broadcast is valid.
3671 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3672 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3673 return batchVariantMatmulOp->emitOpError()
3674 << "Invalid broadcast requested.";
3675 } else if (!hasValidBatchDim(opIndexingMap)) {
3676 return batchVariantMatmulOp->emitOpError()
3677 << "Invalid batch dimension expression.";
3678 }
3679 return success();
3680}
3681
3682/// This function checks if the given AffineMap for the output of a
3683/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3684/// dimensions and if the output map result dimensions are valid.
3685template <typename OpTy>
3686static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3687 AffineMap opIndexingMap) {
3688 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3689 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3690 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3691 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3692 opIndexingMap.getNumResults() != 3) {
3693
3694 return batchVariantMatmulOp->emitOpError()
3695 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3696 << ").";
3697 }
3698 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3699 opIndexingMap.getNumResults() != 2) {
3700 return batchVariantMatmulOp->emitOpError()
3701 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3702 << ").";
3703 }
3704
3705 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3706 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3707 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3708 outputMap.getResult(1).isFunctionOfDim(1) &&
3709 outputMap.getResult(2).isFunctionOfDim(2)
3710 : outputMap.getResult(0).isFunctionOfDim(1) &&
3711 outputMap.getResult(1).isFunctionOfDim(2);
3712 };
3713
3714 if (!areValidOutputResultDim(opIndexingMap)) {
3715 return batchVariantMatmulOp->emitOpError()
3716 << "Invalid output map result dimension.";
3717 }
3718
3719 return success();
3720}
3721
3722/// Verifies the broadcast and transpose semantic specified by the explicit
3723/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3724/// specified by opIndex.
3725template <typename OpTy>
3726static LogicalResult
3728 unsigned opIndex) {
3729 SmallVector<AffineMap, 3> opIndexingMaps =
3730 batchVariantMatmulOp.getIndexingMapsArray();
3731 SmallVector<AffineMap, 3> defaultIndexingMaps =
3732 batchVariantMatmulOp.getDefaultIndexingMaps(
3733 batchVariantMatmulOp->getContext());
3734
3735 if (opIndexingMaps.size() != 3)
3736 return batchVariantMatmulOp->emitOpError()
3737 << "Indexing_map attribute must have 3 affine maps.";
3738
3739 auto opIndexingMap = opIndexingMaps[opIndex];
3740 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3741
3742 if (opIndex == 2 &&
3743 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3744 return failure();
3745
3746 if (opIndex != 2 &&
3747 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3748 defaultIndexingMap, opIndex == 0)))
3749 return failure();
3750
3751 return success();
3752}
3753
3754namespace mlir {
3755namespace linalg {
3756
3757std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3758 if (m == 2 && r == 3)
3759 return WinogradConv2DFmr::F_2_3;
3760 if (m == 4 && r == 3)
3761 return WinogradConv2DFmr::F_4_3;
3762 if (m == 2 && r == 5)
3763 return WinogradConv2DFmr::F_2_5;
3764 return std::nullopt;
3765}
3766
3767std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3768 switch (fmr) {
3769 case WinogradConv2DFmr::F_2_3:
3770 return {2, 3};
3771 case WinogradConv2DFmr::F_4_3:
3772 return {4, 3};
3773 case WinogradConv2DFmr::F_2_5:
3774 return {2, 5};
3775 }
3776 llvm_unreachable("Unkown WinogradConv2DFmr");
3777}
3778
3779//===----------------------------------------------------------------------===//
3780// MatMulOp
3781//===----------------------------------------------------------------------===//
3782
3783static FailureOr<SmallVector<SmallVector<int64_t>>>
3786 for (auto map : maps) {
3787 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3788 if (!attr)
3789 return failure();
3791 for (auto result : attr.getAffineMap().getResults()) {
3792 auto dim = dyn_cast<AffineDimExpr>(result);
3793 if (!dim)
3794 return failure();
3795 pos.push_back(dim.getPosition());
3796 }
3797 positions.push_back(pos);
3798 }
3799 return positions;
3800}
3801
3802/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3803SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3804 AffineExpr d0, d1, d2;
3805 SmallVector<AffineMap> indexingMaps;
3806 bindDims(context, d0, d1, d2);
3807 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3808 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3809 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3810 return indexingMaps;
3811}
3812
3813bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3814 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3815 if (!maps)
3816 return false;
3817 if (maps.size() != 3)
3818 return false;
3819 auto positions = getAffineResultPositions(maps);
3820 if (failed(positions))
3821 return false;
3822 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3823 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3824 (*positions)[2] == SmallVector<int64_t>{0, 1};
3825}
3826
3827SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3828 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3829 utils::IteratorType::parallel,
3830 utils::IteratorType::reduction};
3831}
3832
// Number of block arguments of the op region: lhs, rhs and the accumulator.
unsigned MatmulOp::getNumRegionArgs() { return 3; }
3834
// Derives the external library call name from the operation.
std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
3838
// Indexing maps are carried as an attribute and may differ from the defaults.
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3840
3841/// Check if the op has broadcast and/or transpose semantic. Returns true if
3842/// the user defined indexing maps are not equal to default map.
3843bool MatmulOp::hasUserDefinedMaps() {
3844 SmallVector<AffineMap, 3> defaultMaps =
3845 getDefaultIndexingMaps(this->getContext());
3846 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3847 return defaultMaps != explicitMaps;
3848}
3849
/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  // Report a malformed region through the diagnostic callback when one is
  // available; otherwise fall through to the assertion below.
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  // The cast kind defaults to signed and can be overridden through the
  // "cast" attribute.
  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  // Cast both inputs to the accumulator element type, multiply them, and add
  // the product into the accumulator (block argument 2). Each build step may
  // fail and return a null Value; bail out in that case.
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value1 || !value2 || !value3)
    return;
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
3888
3889/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3890/// broadcast map must include K dimension.
3891/// TODO: Strict inclusion of K dimension in the broadcast map is not
3892/// necessary for both input matrices simultaneously. We can relax this
3893/// condition to have K dimension for one input matrix map and infer the K
3894/// dimension for other input matrix map from the one already having K
3895/// dimension.
3896bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3897 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3898 AffineExpr expr = bcastMap.getResult(0);
3899 // Invalid map if the common dimension of matmul not found.
3900 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3901}
3902
3903static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3904 if (parser.parseOptionalKeyword("indexing_maps"))
3905 return ArrayAttr{
3906 nullptr}; // Success in case indexing_maps was not provided.
3907
3908 ArrayAttr arrayAttr;
3909 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3910 return failure();
3911
3912 if (llvm::any_of(arrayAttr,
3913 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3914 return parser.emitError(parser.getCurrentLocation())
3915 << "element of indexing_maps array is not an affine_map";
3916
3917 return arrayAttr;
3918}
3919
ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  // The `indexing_maps` clause is optional; a null ArrayAttr means it was
  // absent from the assembly.
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    // Not provided: materialize the default (plain matmul) indexing maps.
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}
3936
void MatmulOp::print(OpAsmPrinter &p) {
  // Print `indexing_maps` only when they differ from the default matmul maps.
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  // These attributes are either implied or were already printed above.
  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}
3949
3950/// Verify the user defined indexing maps.
3951LogicalResult MatmulOp::verify() {
3952 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3953 if (!hasUserDefinedMaps())
3954 return success();
3955
3956 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3957 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3958 return failure();
3959 }
3960 return success();
3961}
3962
// Folds away memref.cast producers of the operands where possible.
LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
3966
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // With pure tensor semantics the op has no memory effects; only buffer
  // semantics imply reads/writes.
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
3974
// Delegates speculatability to the generic structured-op implementation.
Speculation::Speculatability MatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
3978
3979SmallVector<AffineMap>
3980MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3981 AffineExpr d0, d1, d2;
3982 MLIRContext *context = builder.getContext();
3983 bindDims(context, d0, d1, d2);
3984 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3985 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3986 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3987 return {mapLHS, mapRHS, mapOut};
3988}
3989
3991 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3992 if (!maps)
3993 return false;
3994 if (maps.size() != 3)
3995 return false;
3996 auto positions = getAffineResultPositions(maps);
3997 if (failed(positions))
3998 return false;
3999 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4000 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4001 (*positions)[2] == SmallVector<int64_t>{0, 1};
4002}
4003
4006 ValueRange inputs, ValueRange outputs,
4007 ArrayRef<NamedAttribute> attributes) {
4008 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4009 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4010}
4011
4014 ValueRange inputs, ValueRange outputs,
4015 ArrayRef<NamedAttribute> attributes) {
4016 OperationState state(location, getOperationName());
4017 build(builder, state, inputs, outputs, attributes);
4018 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4019 assert(res && "builder didn't return the right type");
4020 return res;
4021}
4022
4025 TypeRange resultTensorTypes,
4026 ValueRange inputs, ValueRange outputs,
4027 ArrayRef<NamedAttribute> attributes) {
4028 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4029 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4030}
4031
4034 TypeRange resultTensorTypes, ValueRange inputs,
4035 ValueRange outputs,
4036 ArrayRef<NamedAttribute> attributes) {
4037 OperationState state(location, getOperationName());
4038 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4039 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4040 assert(res && "builder didn't return the right type");
4041 return res;
4042}
4043
4046 TypeRange resultTensorTypes,
4047 ValueRange inputs, ValueRange outputs,
4048 Attribute cast,
4049 ArrayRef<NamedAttribute> attributes) {
4050 result.addAttribute("cast", cast);
4051 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4052 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4053}
4054
4057 TypeRange resultTensorTypes, ValueRange inputs,
4058 ValueRange outputs, Attribute cast,
4059 ArrayRef<NamedAttribute> attributes) {
4060 OperationState state(location, getOperationName());
4061 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4062 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4063 assert(res && "builder didn't return the right type");
4064 return res;
4065}
4066
4068 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4070 op->getAttr("indexing_maps"));
4071}
4072
4074MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4075 AffineExpr d0, d1, d2;
4076 MLIRContext *context = builder.getContext();
4077 bindDims(context, d0, d1, d2);
4078 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4079 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4080 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4081 return {mapLHS, mapRHS, mapOut};
4082}
4083
4085 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4086 if (!maps)
4087 return false;
4088 if (maps.size() != 3)
4089 return false;
4090 auto positions = getAffineResultPositions(maps);
4091 if (failed(positions))
4092 return false;
4093 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4094 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4095 (*positions)[2] == SmallVector<int64_t>{0, 1};
4096}
4097
4100 ValueRange inputs, ValueRange outputs,
4101 ArrayRef<NamedAttribute> attributes) {
4102 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4103 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4104}
4105
4108 ValueRange inputs, ValueRange outputs,
4109 ArrayRef<NamedAttribute> attributes) {
4110 OperationState state(location, getOperationName());
4111 build(builder, state, inputs, outputs, attributes);
4112 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4113 assert(res && "builder didn't return the right type");
4114 return res;
4115}
4116
4119 TypeRange resultTensorTypes,
4120 ValueRange inputs, ValueRange outputs,
4121 ArrayRef<NamedAttribute> attributes) {
4122 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4123 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4124}
4125
4128 TypeRange resultTensorTypes, ValueRange inputs,
4129 ValueRange outputs,
4130 ArrayRef<NamedAttribute> attributes) {
4131 OperationState state(location, getOperationName());
4132 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4133 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4134 assert(res && "builder didn't return the right type");
4135 return res;
4136}
4137
4140 TypeRange resultTensorTypes,
4141 ValueRange inputs, ValueRange outputs,
4142 Attribute cast,
4143 ArrayRef<NamedAttribute> attributes) {
4144 result.addAttribute("cast", cast);
4145 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4146 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4147}
4148
4151 TypeRange resultTensorTypes, ValueRange inputs,
4152 ValueRange outputs, Attribute cast,
4153 ArrayRef<NamedAttribute> attributes) {
4154 OperationState state(location, getOperationName());
4155 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4156 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4157 assert(res && "builder didn't return the right type");
4158 return res;
4159}
4160
4162 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4164 op->getAttr("indexing_maps"));
4165}
4166
4168BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4169 AffineExpr d0, d1, d2, d3;
4170 MLIRContext *context = builder.getContext();
4171 bindDims(context, d0, d1, d2, d3);
4172 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4173 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4174 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4175 return {mapLHS, mapRHS, mapOut};
4176}
4177
4179 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4180 if (!maps)
4181 return false;
4182 if (maps.size() != 3)
4183 return false;
4184 auto positions = getAffineResultPositions(maps);
4185 if (failed(positions))
4186 return false;
4187 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4188 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4189 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4190}
4191
4193 OpBuilder &builder, OperationState &result, ValueRange inputs,
4194 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4195 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4196 BatchMatmulOp::getRegionBuilder(),
4197 getDefaultIndexingMaps(builder));
4198}
4199
4202 ValueRange inputs, ValueRange outputs,
4203 ArrayRef<NamedAttribute> attributes) {
4204 OperationState state(location, getOperationName());
4205 build(builder, state, inputs, outputs, attributes);
4206 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4207 assert(res && "builder didn't return the right type");
4208 return res;
4209}
4210
4212 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4213 ValueRange inputs, ValueRange outputs,
4214 ArrayRef<NamedAttribute> attributes) {
4215 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4216 BatchMatmulOp::getRegionBuilder(),
4217 getDefaultIndexingMaps(builder));
4218}
4219
4222 TypeRange resultTensorTypes, ValueRange inputs,
4223 ValueRange outputs,
4224 ArrayRef<NamedAttribute> attributes) {
4225 OperationState state(location, getOperationName());
4226 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4227 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4228 assert(res && "builder didn't return the right type");
4229 return res;
4230}
4231
4233 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4234 ValueRange inputs, ValueRange outputs, Attribute cast,
4235 ArrayRef<NamedAttribute> attributes) {
4236 result.addAttribute("cast", cast);
4237 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4238 BatchMatmulOp::getRegionBuilder(),
4239 getDefaultIndexingMaps(builder));
4240}
4241
4244 TypeRange resultTensorTypes, ValueRange inputs,
4245 ValueRange outputs, Attribute cast,
4246 ArrayRef<NamedAttribute> attributes) {
4247 OperationState state(location, getOperationName());
4248 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4249 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4250 assert(res && "builder didn't return the right type");
4251 return res;
4252}
4253
4255 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4257 op->getAttr("indexing_maps"));
4258}
4259
4261BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4262 AffineExpr d0, d1, d2, d3;
4263 MLIRContext *context = builder.getContext();
4264 bindDims(context, d0, d1, d2, d3);
4265 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4266 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4267 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4268 return {mapLHS, mapRHS, mapOut};
4269}
4270
4272 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4273 if (!maps)
4274 return false;
4275 if (maps.size() != 3)
4276 return false;
4277 auto positions = getAffineResultPositions(maps);
4278 if (failed(positions))
4279 return false;
4280 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4281 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4282 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4283}
4284
4286 OpBuilder &builder, OperationState &result, ValueRange inputs,
4287 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4288 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4289 BatchMatmulOp::getRegionBuilder(),
4290 getDefaultIndexingMaps(builder));
4291}
4292
4295 ValueRange inputs, ValueRange outputs,
4296 ArrayRef<NamedAttribute> attributes) {
4297 OperationState state(location, getOperationName());
4298 build(builder, state, inputs, outputs, attributes);
4299 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4300 assert(res && "builder didn't return the right type");
4301 return res;
4302}
4303
4305 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4306 ValueRange inputs, ValueRange outputs,
4307 ArrayRef<NamedAttribute> attributes) {
4308 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4309 BatchMatmulOp::getRegionBuilder(),
4310 getDefaultIndexingMaps(builder));
4311}
4312
4315 TypeRange resultTensorTypes, ValueRange inputs,
4316 ValueRange outputs,
4317 ArrayRef<NamedAttribute> attributes) {
4318 OperationState state(location, getOperationName());
4319 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4320 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4321 assert(res && "builder didn't return the right type");
4322 return res;
4323}
4324
4326 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4327 ValueRange inputs, ValueRange outputs, Attribute cast,
4328 ArrayRef<NamedAttribute> attributes) {
4329 result.addAttribute("cast", cast);
4330 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4331 BatchMatmulOp::getRegionBuilder(),
4332 getDefaultIndexingMaps(builder));
4333}
4334
4337 TypeRange resultTensorTypes, ValueRange inputs,
4338 ValueRange outputs, Attribute cast,
4339 ArrayRef<NamedAttribute> attributes) {
4340 OperationState state(location, getOperationName());
4341 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4342 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4343 assert(res && "builder didn't return the right type");
4344 return res;
4345}
4346
4348 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4350 op->getAttr("indexing_maps"));
4351}
4352
4353//===----------------------------------------------------------------------===//
4354// ContractOp
4355//===----------------------------------------------------------------------===//
4356
4357SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4358 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4359 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4360 // domains are all the same, and each implements a projected permutation.
4361 // Each iteration space dim must occur for at least one operand and either
4362 // takes part in a contraction/reduction or else has parallel iteration type.
4363 // We have that a dim is a contraction/reduction dim if and only if the dim
4364 // occurs for the output operand. We use this fact for fast inference:
4365 // NB: In case we allow dims to occur solely for one input, the above still
4366 // holds: per the einsum semantics, these are reduction dims as well.
4367 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4368 for (auto result : outAffineMap.getResults()) {
4369 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4370 assert(dimExpr && "affine_map is a projected permutation");
4371 dimsInOutput[dimExpr.getPosition()] = true;
4372 }
4373
4375 for (auto dimOccursInOutput : dimsInOutput)
4376 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4377 : utils::IteratorType::reduction);
4378
4379 return iteratorTypes;
4380}
4381
/// The op's region takes 3 block arguments: two inputs and one output.
unsigned ContractOp::getNumRegionArgs() { return 3; }
4383
4384/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4385void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4386 ArrayRef<NamedAttribute> attrs,
4387 function_ref<InFlightDiagnostic()> emitError) {
4388 if (emitError && block.getNumArguments() != 3) {
4389 emitError() << "ContractOp regionBuilder expects 3 args, got "
4390 << block.getNumArguments();
4391 return;
4392 }
4393 assert(block.getNumArguments() == 3 &&
4394 "ContractOp regionBuilder expects 3 args");
4395 RegionBuilderHelper helper(b, block);
4396
4397 TypeFn castSignedness = TypeFn::cast_signed;
4398 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4399 return attr.getName() == "cast";
4400 });
4401 if (castIter != attrs.end()) {
4402 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4403 castSignedness = attr.getValue();
4404 }
4405
4406 // TODO: Support fields with operators besides mult & add.
4407 Type outType = block.getArgument(2).getType();
4408 Value lhsAtOutType =
4409 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4410 Value rhsAtOutType =
4411 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4412 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4413 rhsAtOutType, emitError);
4414 if (!productAtOutType)
4415 return;
4416 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4417 productAtOutType, emitError);
4418 if (!result)
4419 return;
4420 helper.yieldOutputs({result});
4421}
4422
4423ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4424 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4425 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4426 return parser.emitError(parser.getCurrentLocation(),
4427 "expected 'indexing_maps' attribute");
4428 result.addAttribute("indexing_maps", *indexingMapsAttr);
4429
4430 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4431 regionBuilder);
4432}
4433
4434void ContractOp::print(OpAsmPrinter &p) {
4435 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4437 p, getOperation(), getInputs(), getOutputs(),
4438 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4439}
4440
/// Verifies that the indexing maps form a well-defined contraction: each map
/// is a projected permutation over one shared iteration space, every
/// iteration-space dim is used, and at least one dim is contracting.
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
  // access an input operand (so occurrence count can be at most 2) and
  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // A helper so that for each operand's affine_map and type we check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      // A scalar (non-shaped) operand must have a zero-result map.
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    if (iterationSpaceDims == -1) {
      // First map seen: it defines the iteration space for all others.
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  // Operand order is (lhs, rhs, out), hence {true, true, false} for isInput.
  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // We have a contracting dim if and only if ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: We disallow a dim which occurs for only one input operand and not
    //     for the output. In terms of einsum semantics such dims have a
    //     sensible meaning - namely an additional reduction per each such dim.
    //     By contrast, the ContractionOpInterface does not know about this
    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
    //     while vector.contract's verifier accepts dims of this kind many of
    //     its lowerings give up on encountering these dims.
    // TODO: Remove following once we have comprehensive support for input-only
    //       reduction dims, at both the linalg- and vector-dialect levels.
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}
4530
/// Only folds `memref.cast` producers into this op; no value folding.
LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4534
/// Reports memory effects for the buffer-semantics form; the pure-tensor form
/// has no memory effects.
void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
4542
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability ContractOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4546
4547//===----------------------------------------------------------------------===//
4548// Implementation of BatchMatmulOp
4549//===----------------------------------------------------------------------===//
4550SmallVector<AffineMap>
4551BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4552 AffineExpr d0, d1, d2, d3;
4553 SmallVector<AffineMap> indexingMaps;
4554 bindDims(context, d0, d1, d2, d3);
4555 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4556 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4557 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4558 return indexingMaps;
4559}
4560
4561bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4562 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4563 if (!maps)
4564 return false;
4565 if (maps.size() != 3)
4566 return false;
4567 auto positions = getAffineResultPositions(maps);
4568 if (failed(positions))
4569 return false;
4570 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4571 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4572 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4573}
4574
4575SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4576 return SmallVector<utils::IteratorType>{
4577 utils::IteratorType::parallel, utils::IteratorType::parallel,
4578 utils::IteratorType::parallel, utils::IteratorType::reduction};
4579}
4580
/// The op's region takes 3 block arguments: two inputs and one output.
unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4582
/// Returns the external library call name generated for this op.
std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
4586
4587/// Check if the op has broadcast and/or transpose semantic. Returns true if
4588/// the user defined indexing maps are not equal to default map.
4589bool BatchMatmulOp::hasUserDefinedMaps() {
4590 SmallVector<AffineMap, 3> defaultMaps =
4591 getDefaultIndexingMaps(this->getContext());
4592 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4593 return defaultMaps != explicitMaps;
4594}
4595
4596/// Returns true if the given bcastMap map is a valid broadcast map. A valid
4597/// broadcast map must include K dimension.
4598/// TODO: Strict inclusion of K dimension in the broadcast map is not
4599/// necessary for both input matrices simultaneously. We can relax this
4600/// condition to have K dimension for one input matrix map and infer the K
4601/// dimension for other input matrix map from the one already having K
4602/// dimension.
4603bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4604 assert(bcastMap.getNumResults() < 3 &&
4605 "Expected less than 3 result dim expr.");
4606 bool isValid = false;
4607 enum Indices { batchPos, mPos, nPos, kPos };
4608 if (bcastMap.getNumResults() == 1) {
4609 AffineExpr expr = bcastMap.getResult(0);
4610 isValid = expr.isFunctionOfDim(kPos);
4611 } else if (bcastMap.getNumResults() == 2) {
4612 AffineExpr expr0 = bcastMap.getResult(0);
4613 AffineExpr expr1 = bcastMap.getResult(1);
4614 isValid =
4615 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4616 expr0.isFunctionOfDim(mPos)) &&
4617 expr1.isFunctionOfDim(kPos))
4618 : ((expr0.isFunctionOfDim(batchPos) &&
4619 expr1.isFunctionOfDim(kPos)) ||
4620 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4621 }
4622 return isValid;
4623}
4624
4625void BatchMatmulOp::regionBuilder(
4626 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4627 function_ref<InFlightDiagnostic()> emitError) {
4628 if (emitError && block.getNumArguments() != 3) {
4629 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4630 << block.getNumArguments();
4631 return;
4632 }
4633 assert(block.getNumArguments() == 3 &&
4634 "BatchMatmulOp regionBuilder expects 3 args");
4635 RegionBuilderHelper helper(b, block);
4636 SmallVector<Value> yields;
4637
4638 TypeFn castVal = TypeFn::cast_signed;
4639 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4640 return attr.getName() == "cast";
4641 });
4642 if (castIter != attrs.end()) {
4643 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4644 castVal = attr.getValue();
4645 }
4646
4647 auto toType = block.getArgument(2).getType();
4648 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4649 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4650 Value mulVal =
4651 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
4652 if (!castValA || !castValB || !mulVal)
4653 return;
4654 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4655 mulVal, emitError);
4656 if (!addVal)
4657 return;
4658 yields.push_back(addVal);
4659 helper.yieldOutputs(yields);
4660}
4661
4662ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4663 SmallVector<Attribute, 3> indexingMapsAttr;
4664 Attribute mapAttr;
4665 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4666 if (parser.parseEqual())
4667 return failure();
4668
4669 if (parser.parseLSquare())
4670 return failure();
4671
4672 do {
4673 if (parser.parseAttribute(mapAttr))
4674 return failure();
4675 if (!isa<AffineMapAttr>(mapAttr)) {
4676 return parser.emitError(parser.getCurrentLocation(),
4677 "expected affine map attribute");
4678 }
4679 indexingMapsAttr.push_back(mapAttr);
4680
4681 if (parser.parseOptionalComma())
4682 break;
4683 } while (true);
4684
4685 if (parser.parseRSquare())
4686 return failure();
4687 }
4688 // Initialize indexingMaps, if not supplied explicitly.
4689 if (indexingMapsAttr.empty()) {
4690 indexingMapsAttr = llvm::map_to_vector(
4691 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4692 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4693 }
4694 result.addAttribute("indexing_maps",
4695 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4696
4697 return ::parseNamedStructuredOp(parser, result,
4698 BatchMatmulOp::getNumRegionArgs(),
4699 BatchMatmulOp::getRegionBuilder());
4700}
4701
4702void BatchMatmulOp::print(OpAsmPrinter &p) {
4703 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4704 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4705 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4706 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4707 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4708
4709 std::array<StringRef, 3> elidedAttrs = {
4710 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4711 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4712 elidedAttrs);
4713}
4714
4715/// Verify the user defined indexing maps.
4716LogicalResult BatchMatmulOp::verify() {
4717 // Verification of pure batch_matmul is handled by
4718 // verifyStructuredOpInterface().
4719 if (!hasUserDefinedMaps())
4720 return success();
4721
4722 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4724 return failure();
4725 }
4726 return success();
4727}
4728
/// Only folds `memref.cast` producers into this op; no value folding.
LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4733
/// Reports memory effects for the buffer-semantics form; the pure-tensor form
/// has no memory effects.
void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
4741
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4745
4746//===----------------------------------------------------------------------===//
4747// ElementwiseOp
4748//===----------------------------------------------------------------------===//
4749//
namespace {
/// Pairs an elementwise op's arity group with the concrete function (of that
/// arity) it applies. The union is discriminated by `arityGroup`.
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ..}
  ElementwiseArityGroup arityGroup;

  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

/// Underlying value of the arity-group enum; equals the number of inputs
/// (Unary -> 1, Binary -> 2, Ternary -> 3).
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}
} // namespace
4767
4768static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4769 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4770 constexpr int lastBinary =
4771 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4772 constexpr int lastTernary =
4773 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4774
4775 int val = static_cast<int>(kind);
4776 ArityGroupAndKind result;
4777
4778 if (val < lastUnary) {
4779 result.arityGroup = ElementwiseArityGroup::Unary;
4780 result.kind.unaryFn = static_cast<UnaryFn>(val);
4781 return result;
4782 }
4783 if (val < lastBinary) {
4784 result.arityGroup = ElementwiseArityGroup::Binary;
4785 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4786 return result;
4787 }
4788 if (val >= lastTernary) {
4789 llvm_unreachable("unhandled ElementwiseFn");
4790 }
4791 result.arityGroup = ElementwiseArityGroup::Ternary;
4792 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4793 return result;
4794}
4795
4796SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4797 auto rank = getResultRank();
4798 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4799}
4800
4802ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4803 MLIRContext *context) {
4804 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4805 return SmallVector<AffineMap>(numMaps, map);
4806}
4807
4808ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4809 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4810 Attribute attr;
4811 mlir::linalg::ElementwiseKind elemwiseKindVal;
4812 if (parser.parseKeyword("kind") || parser.parseEqual())
4813 return failure();
4814
4815 if (succeeded(parser.parseAttribute(attr))) {
4816 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4817 if (!elemwiseKindAttr)
4818 return parser.emitError(parser.getCurrentLocation(),
4819 "expected ElementwiseKind attribute");
4820 elemwiseKindVal = elemwiseKindAttr.getValue();
4821 } else {
4822 return parser.emitError(parser.getCurrentLocation(),
4823 "expected operation 'kind' attribute");
4824 }
4825 result.addAttribute(
4826 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4827
4828 // Parse optional `indexing_maps`
4829 SmallVector<Attribute, 3> indexingMapsAttr;
4830 Attribute mapAttr;
4831 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4832 if (parser.parseEqual())
4833 return failure();
4834 if (parser.parseLSquare())
4835 return failure();
4836 do {
4837 if (parser.parseAttribute(mapAttr))
4838 return failure();
4839 if (!isa<AffineMapAttr>(mapAttr))
4840 return parser.emitError(parser.getCurrentLocation(),
4841 "expected affine map attribute");
4842 indexingMapsAttr.push_back(mapAttr);
4843 if (parser.parseOptionalComma())
4844 break;
4845 } while (true);
4846 if (parser.parseRSquare())
4847 return failure();
4848 }
4849 // At this stage of parsing the only way to infer number of region
4850 // args is through op kind, as input output tensors are not parsed yet.
4851 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4852 int numRegionArgs =
4853 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4854 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4855 ElementwiseOp::getRegionBuilder())) {
4856 return parser.emitError(parser.getCurrentLocation(),
4857 "unable to parse elemwise op");
4858 }
4859
4860 // Initialize indexingMaps, if not supplied explicitly.
4861 if (indexingMapsAttr.empty()) {
4862 // We need to infer the numDims of the indexing maps from the output
4863 // type which is already parsed by now.
4864 auto resultType = result.operands[result.operands.size() - 1].getType();
4865 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4866 if (!shapedType)
4867 return parser.emitError(parser.getCurrentLocation(),
4868 "return type needs to be shaped type");
4869 auto numDims = shapedType.getRank();
4870 indexingMapsAttr = llvm::map_to_vector(
4871 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4872 parser.getContext()),
4873 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4874 }
4875
4876 result.addAttribute("indexing_maps",
4877 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4878 return success();
4879}
4880
4881void ElementwiseOp::print(OpAsmPrinter &p) {
4882 p << " kind=";
4883 p.printAttribute(getKindAttr());
4884 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4885 "indexing_maps"};
4886 unsigned arity =
4887 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4888 unsigned numDims = getResultRank();
4889
4890 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4891 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4892 getContext()),
4893 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4894
4895 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4896 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4897
4898 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4899 elidedAttrs);
4900}
4901
4902/// Implements the block region builder for the ElementwiseOp. This is called by
4903/// 'fillStructuredOpRegion'.
4904void ElementwiseOp::regionBuilder(
4905 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4906 function_ref<InFlightDiagnostic()> emitError) {
4907 ElementwiseKind elemwiseKind;
4908 for (auto attr : attrs) {
4909 if (attr.getName() == b.getStringAttr("kind")) {
4910 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4911 assert(kindAttr && "op kind attribute incorrectly set");
4912 elemwiseKind = kindAttr.getValue();
4913 break;
4914 }
4915 }
4916
4917 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4918 auto arityGroup = groupAndKind.arityGroup;
4919 auto kind = groupAndKind.kind;
4920 if (emitError && block.getNumArguments() !=
4921 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4922 emitError() << "Elementwise regionBuilder expects "
4923 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4924 << block.getNumArguments();
4925 return;
4926 }
4927 assert(block.getNumArguments() ==
4928 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4929 && "Elementwise regionBuilder number of block args mismatch");
4930
4931 RegionBuilderHelper helper(b, block);
4932 SmallVector<Value> yields;
4933 Value result;
4934
4935 if (arityGroup == ElementwiseArityGroup::Unary) {
4936 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4937
4938 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4939 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4940 block.getArgument(1));
4941
4942 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4943 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4944 block.getArgument(1), block.getArgument(2));
4945
4946 } else {
4947 assert(false && "found unhandled category in elemwise");
4948 }
4949
4950 yields.push_back(result);
4951 helper.yieldOutputs(yields);
4952}
4953
/// Only folds `memref.cast` producers into this op; no value folding.
LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
4958
/// Reports memory effects for the buffer-semantics form; the pure-tensor form
/// has no memory effects.
void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
4966
/// Speculatability is delegated to the generic structured-op implementation.
Speculation::Speculatability ElementwiseOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
4970
4971//===----------------------------------------------------------------------===//
4972// PackOp/UnPackOp Common
4973//===----------------------------------------------------------------------===//
4974
4975template <typename OpTy, typename>
4976SmallVector<int64_t>
4978 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4979 ? packOrUnPack.getDestType()
4980 : packOrUnPack.getSourceType();
4981 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4982 ? packOrUnPack.getSourceType()
4983 : packOrUnPack.getDestType();
4985 packedType.getShape().take_front(unpackedType.getRank()));
4986 if (!packOrUnPack.getOuterDimsPerm().empty()) {
4988 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4989 }
4990 return result;
4991}
4996
// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
// * a dim from newPackedTy is static, and
// * the corresponding size from mixedTiles is still dynamic.
// Otherwise, the original tile size is preserved.
// Note - packed-type-dim and mixed-tile-size should always match!
// NOTE(review): extraction dropped the first signature line(s) of this helper
// (it takes at least a rewriter, the new packed type `newPackedTy`, and
// `mixedTiles`); restore from the upstream file before compiling.
    SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  // Walk the trailing (tile) dims of the packed type together with the
  // current mixed tile sizes; the two ranges have equal length by contract.
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    // Dynamic packed dim: keep whatever tile size we already had.
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, update the dynamic mixed-size
    // (provided the original value is dynamic).
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant
      newMixedTileSizes.push_back(tile);
    } else {
      // Dynamic tile size with a now-static dim: materialize the constant.
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
    }
  }

  return newMixedTileSizes;
}
5034
// Reifies the result shape of a pack/unpack op: every result dim equals the
// corresponding (possibly folded) dim of the dest operand.
// NOTE(review): extraction dropped the line naming this helper (presumably
// `reifyResultShapesImpl(OpTy op, OpBuilder &builder,`).
template <typename OpTy>
static LogicalResult
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  for (auto dim : llvm::seq<int64_t>(0, destRank))
    reifiedReturnShapes[0][dim] =
        createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
  return success();
}
5048
// Builds a map from each tiled dimension position (`inner_dims_pos` entry)
// to its mixed (static-or-dynamic) tile size.
// NOTE(review): extraction dropped the declaration line (presumably
// `static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {`).
template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind the dimension `i` with the tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
5063
// Merges the static tile sizes with the dynamic tile values: each dynamic
// sentinel in `static_inner_tiles` consumes the next `inner_tiles` operand.
// NOTE(review): extraction dropped the declaration line (presumably
// `static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {`).
template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (ShapedType::isStatic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}
5079
// Returns the static components of the mixed tiles; dynamic entries become
// kDynamic sentinels (the dynamic Values are discarded here).
// NOTE(review): extraction dropped the declaration line (presumably
// `static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {`).
template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}
5089
/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicates.
/// b) At least one dimension is out of bounds (a valid `dimPos` is >= 0 and
///    < `rank`).
/// c) The number of elements in `dimsPos` is > than `rank`.
// NOTE(review): extraction dropped the first signature line (presumably
// `static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,`).
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  // Deduplicate: any repeated position is invalid.
  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
5106
/// Shared verifier for linalg.pack and linalg.unpack. Checks, in order:
/// ranked operands, consistent tensor-vs-buffer semantics and result count,
/// non-zero tiles, valid `inner_dims_pos`/`outer_dims_perm`, rank bookkeeping
/// (packed rank == unpacked rank + #tiles), and that the packed shape matches
/// the shape inferred from the tiles.
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(tiles, isZeroInteger);
  };

  // Verify that the source and destination are ranked types.
  if (!packOrUnPack.getSourceType().hasRank() ||
      !packOrUnPack.getDestType().hasRank())
    return op->emitError("expected both source and destination to have rank");

  // Verify that the Operation does not have mixed tensor/buffer semantics.
  if (!packOrUnPack.hasPureBufferSemantics() &&
      !packOrUnPack.hasPureTensorSemantics())
    return op->emitError("mixing tensor and buffer semantics is not allowed");
  // Tensor semantics produce exactly one result; buffer semantics produce none.
  const unsigned numResults = packOrUnPack.getNumResults();
  if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
    return op->emitError("expected 1 result, got ") << numResults;
  if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
    return op->emitError("expected 0 results, got ") << numResults;

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  // The unpacked side is the source for pack and the dest for unpack.
  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                ? packOrUnPack.getSourceType()
                                : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and must match the number of `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require output rank to match input rank + number of blocking factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify result shape is greater than the minimum expected
  // by the pack operation, and that the output shape
  // represents full tiles.
  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
      unpackedType.getShape(), packOrUnPack.getStaticTiles(),
      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
  // Each trailing (tile) dim of the packed type must equal its static tile
  // size; with a dynamic tile size the dim must be dynamic as well.
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              // NOTE(review): `intAttr` is not null-checked; this assumes a
              // static tile is always encoded as an IntegerAttr — confirm.
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  if (failed(
          verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
    auto elementType = unpackedType.getElementType();
    Type expectedType, actualType;
    if (packOrUnPack.hasPureTensorSemantics()) {
      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
      actualType = RankedTensorType::get(packedType.getShape(), elementType);
    } else {
      expectedType = MemRefType::get(expectedPackedShape, elementType);
      actualType = MemRefType::get(packedType.getShape(), elementType);
    }
    return op->emitError("expected ")
           << expectedType << " for the packed domain value, got "
           << actualType;
  }
  return success();
}
5214
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  // Permuted `inner_dims_pos` attribute values.
  SmallVector<int64_t> innerDimsPos;
  // Permuted mixed (static/dynamic) inner tile sizes.
  SmallVector<OpFoldResult> innerTiles;
  // Permuted `outer_dims_perm` (identity-filled when originally absent).
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace
5228
// Applies `innerPermutation`/`outerPermutation` to the pack/unpack metadata
// (inner_dims_pos, mixed tiles, outer_dims_perm) and returns the result.
// NOTE(review): extraction dropped the function-name line (presumably
// `commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,`).
template <typename OpTy>
static PackOrUnPackTransposeResult
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  // The outer dims live on the unpacked side: source for pack, dest for
  // unpack.
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  // An absent outer_dims_perm is materialized as the identity permutation so
  // the outer permutation below always has something to act on.
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
5265
5266//===----------------------------------------------------------------------===//
5267// PackOp
5268//===----------------------------------------------------------------------===//
5269
5270void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5271 if (!getResults().empty())
5272 setNameFn(getResult(), "pack");
5273}
5274
/// Parse: source [padding_value(...)] [outer_dims_perm = [..]]
/// inner_dims_pos = [..] inner_tiles = [..] into dest attr-dict : srcTy ->
/// dstTy. Tensor semantics add one result type; memref semantics add none.
ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand source, dest;
  // NOTE(review): extraction dropped two declaration lines here (the operand
  // vectors for `paddingValue` and `dynamicTiles`); restore from upstream.
  SmallVector<Type> paddingValueType;
  SmallVector<int64_t> staticTiles;
  DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
  Type sourceType, destType, resultType;

  if (parser.parseOperand(source))
    return failure();

  // Optional `padding_value(%v : type)` clause.
  if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
    if (parser.parseLParen() ||
        parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
        parser.parseColon() || parser.parseTypeList(paddingValueType) ||
        parser.parseRParen())
      return failure();
  }

  // Optional `outer_dims_perm = [..]` integer list.
  if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
    if (parser.parseEqual())
      return failure();

    SmallVector<int64_t> outerDimsPermVec;
    // NOTE(review): extraction dropped the `parseCommaSeparatedList` call
    // that wraps this element-parsing lambda.
      int64_t value;
      if (parser.parseInteger(value))
        return failure();
      outerDimsPermVec.push_back(value);
      return success();
    }))
      return failure();
    outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
  }

  // Mandatory `inner_dims_pos = [..]` integer list.
  if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
    return failure();

  SmallVector<int64_t> innerDimsPosVec;
  // NOTE(review): extraction dropped the `parseCommaSeparatedList` call that
  // wraps this element-parsing lambda.
    int64_t value;
    if (parser.parseInteger(value))
      return failure();
    innerDimsPosVec.push_back(value);
    return success();
  }))
    return failure();
  innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);

  // Mandatory `inner_tiles = [..]` mixed static/dynamic list.
  if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
    return failure();

  DenseI64ArrayAttr staticTilesAttr;
  if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
    return failure();
  for (auto val : staticTilesAttr.asArrayRef())
    staticTiles.push_back(val);

  if (parser.parseKeyword("into") || parser.parseOperand(dest))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (parser.parseColon() || parser.parseType(sourceType))
    return failure();

  // `-> destType` is mandatory; parsed optimistically, diagnosed below.
  bool hasArrow = succeeded(parser.parseOptionalArrow());
  if (hasArrow) {
    if (parser.parseType(destType))
      return failure();
  }

  bool isMemRef = llvm::isa<MemRefType>(sourceType);
  if (!hasArrow) {
    return parser.emitError(parser.getCurrentLocation(),
                            "pack/unpack requires '->' and destination type");
  }

  // Only the tensor form has a result (the dest type).
  if (!isMemRef)
    resultType = destType;

  if (parser.resolveOperand(source, sourceType, result.operands) ||
      parser.resolveOperand(dest, destType, result.operands))
    return failure();

  if (!paddingValue.empty() &&
      parser.resolveOperands(paddingValue, paddingValueType[0],
                             result.operands))
    return failure();

  if (!dynamicTiles.empty() &&
      parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
                             result.operands))
    return failure();

  result.addAttribute("static_inner_tiles",
                      parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
  result.addAttribute("inner_dims_pos", innerDimsPos);
  if (outerDimsPerm)
    result.addAttribute("outer_dims_perm", outerDimsPerm);

  // Operand groups: source, dest, optional padding value, dynamic tiles.
  SmallVector<int32_t> segmentSizes = {
      1, 1, static_cast<int32_t>(paddingValue.size()),
      static_cast<int32_t>(dynamicTiles.size())};
  result.addAttribute("operandSegmentSizes",
                      parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));

  if (!isMemRef)
    result.addTypes(resultType);

  return success();
}
5389
5390void PackOp::print(OpAsmPrinter &p) {
5391 p << " " << getSource();
5392
5393 if (getPaddingValue()) {
5394 p << " padding_value(" << getPaddingValue() << " : "
5395 << getPaddingValue().getType() << ")";
5396 }
5397
5398 if (!getOuterDimsPerm().empty()) {
5399 p << " outer_dims_perm = [";
5400 llvm::interleaveComma(getOuterDimsPerm(), p);
5401 p << "]";
5402 }
5403
5404 p << " inner_dims_pos = [";
5405 llvm::interleaveComma(getInnerDimsPos(), p);
5406 p << "]";
5407
5408 p << " inner_tiles = ";
5409 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
5410
5411 p << " into " << getDest();
5412
5413 p.printOptionalAttrDict((*this)->getAttrs(),
5414 {"static_inner_tiles", "inner_dims_pos",
5415 "outer_dims_perm", "operandSegmentSizes"});
5416
5417 p << " : " << getSource().getType();
5418 p << " -> " << getDest().getType();
5419}
5420
5421void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5422 Value dest, ArrayRef<int64_t> innerDimsPos,
5423 ArrayRef<OpFoldResult> innerTiles,
5424 std::optional<Value> paddingValue,
5425 ArrayRef<int64_t> outerDimsPerm) {
5426 assert(innerDimsPos.size() == innerTiles.size() &&
5427 "number of tile sizes specified must match the specified number of "
5428 "original dimensions to be tiled");
5429 SmallVector<int64_t> staticTileSizes;
5430 SmallVector<Value> dynamicTileSizes;
5431 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5432 build(builder, state, dest.getType(), source, dest,
5433 paddingValue ? *paddingValue : nullptr,
5434 outerDimsPerm.empty() ? nullptr
5435 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5436 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5437 builder.getDenseI64ArrayAttr(staticTileSizes));
5438}
5439
/// Reify the result shape from the dest operand (shared pack/unpack impl).
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
5445
/// Map each tiled dim position to its mixed tile size (shared impl).
DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}
5449
/// Merge static and dynamic inner tiles into one mixed list (shared impl).
SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}
5453
/// Return the static tile sizes, with kDynamic for dynamic ones (shared impl).
SmallVector<int64_t> PackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}
5457
5458ArrayRef<int64_t> PackOp::getAllOuterDims() {
5459 ShapedType inputType = getSourceType();
5460 int64_t inputRank = inputType.getRank();
5461 return getDestType().getShape().take_front(inputRank);
5462}
5463
/// Return the outer dims of the packed type that correspond to the tiled
/// (inner) dims, listed in `inner_dims_pos` order.
SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  SmallVector<int64_t> outerDims(getAllOuterDims());
  // NOTE(review): extraction dropped a line here, likely the declaration of
  // the result vector `res`; restore from upstream before compiling.

  // Recover the original order of the outer dims.
  // NOTE(review): if `invertPermutationVector` returns the inverse by value
  // (rather than inverting in place), its result is discarded on the next
  // line — verify against the current IndexingUtils API.
  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
  invertPermutationVector(outerDimPermInv);
  if (!outerDimPermInv.empty())
    applyPermutationToVector(outerDims, outerDimPermInv);

  // Collect the outer dims corresponding to the tiled inner dims.
  for (auto index : innerDimsPos)
    res.push_back(outerDims[index]);

  return res;
}
5481
5482bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5483 ArrayRef<int64_t> innerDimsPos,
5484 ArrayRef<int64_t> outputShape,
5485 ArrayRef<int64_t> outerDimsPerm,
5486 ArrayRef<OpFoldResult> innerTiles) {
5487 SmallVector<int64_t> outputTileSizes(
5488 outputShape.take_front(inputShape.size()));
5489 if (!outerDimsPerm.empty()) {
5490 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5491 "expected output and outer_dims_perm to have same size");
5492 applyPermutationToVector(outputTileSizes,
5493 invertPermutationVector(outerDimsPerm));
5494 }
5495 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5496 if (ShapedType::isDynamic(inputShape[pos]))
5497 continue;
5498 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5499
5500 if (!constantTile) {
5501 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5502 (inputShape[pos] % outputTileSizes[pos] != 0))
5503 return true;
5504 } else if (inputShape[pos] % (*constantTile) != 0) {
5505 return true;
5506 }
5507 }
5508 return false;
5509}
5510
5511bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5512 ArrayRef<int64_t> innerDimsPos,
5513 ArrayRef<int64_t> outputShape,
5514 ArrayRef<int64_t> outerDimsPerm,
5515 ArrayRef<OpFoldResult> innerTiles) {
5516 SmallVector<int64_t> outputTileSizes(
5517 outputShape.take_front(inputShape.size()));
5518 if (!outerDimsPerm.empty()) {
5519 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5520 "expected output and outer_dims_perm to have same size");
5521 applyPermutationToVector(outputTileSizes,
5522 invertPermutationVector(outerDimsPerm));
5523 }
5524 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5525 if (ShapedType::isDynamic(inputShape[pos]) ||
5526 ShapedType::isDynamic(outputTileSizes[pos]))
5527 return true;
5528 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5529 if (!constantTile)
5530 return true;
5531 if (inputShape[pos] % (*constantTile) != 0)
5532 return true;
5533 }
5534 return false;
5535}
5536
/// Verify pack-specific invariants on top of the common pack/unpack checks:
/// the padding value's type and the full-tile requirement when no padding
/// value is provided.
LogicalResult PackOp::verify() {
  // NOTE(review): extraction dropped the line invoking the common verifier
  // (presumably `if (failed(commonVerifierPackAndUnPackOp(*this)))`) whose
  // failure this propagates.
    return failure();

  // Verify padding value, and bail out if the tile does not divide the
  // dimension fully. In the case of dynamic tile factors or dimensions, having
  // a partial tile is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
5562
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
/// Value's to kDynamic, even if they are arith.constant values.
// NOTE(review): extraction dropped the signature and the declaration of the
// local `result` vector here (presumably `static SmallVector<int64_t>
// asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {` followed by
// `SmallVector<int64_t> result;`); restore from upstream before compiling.
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}
5577
5578SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5579 ArrayRef<int64_t> innerTileSizes,
5580 ArrayRef<int64_t> innerDimsPos,
5581 ArrayRef<int64_t> outerDimsPerm) {
5582 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5583 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5584 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5585 continue;
5586 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5587 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5588 continue;
5589 }
5590 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5591 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5592 }
5593
5594 // Swap tile loops if outer_dims_perm is available.
5595 if (!outerDimsPerm.empty())
5596 applyPermutationToVector(resultShape, outerDimsPerm);
5597
5598 // Append the inner tile dimensions.
5599 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5600 return resultShape;
5601}
5602
5603SmallVector<OpFoldResult> PackOp::getResultShape(
5604 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5605 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5606 ArrayRef<int64_t> outerDimsPerm) {
5607 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5608
5609 AffineExpr s0, s1;
5610 bindSymbols(builder.getContext(), s0, s1);
5611 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5612 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5613 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5614 builder, loc, ceilDivExpr,
5615 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5616 }
5617 if (!outerDimsPerm.empty())
5618 applyPermutationToVector(resultDims, outerDimsPerm);
5619 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5620
5621 SmallVector<int64_t> resultTypeShape =
5622 inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
5623 asShapeWithAnyValueAsDynamic(innerTileSizes),
5624 innerDimsPos, outerDimsPerm);
5625
5626 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5627 // result type shape says it's a dynamic dim. This is needed as callers may
5628 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5629 // dynamic dims returned by that.
5630 for (unsigned i = 0; i < resultDims.size(); ++i) {
5631 if (ShapedType::isStatic(resultTypeShape[i]))
5632 continue;
5633 resultDims[i] =
5634 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5635 }
5636
5637 return resultDims;
5638}
5639
5640RankedTensorType PackOp::inferPackedTensorType(
5641 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5642 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5643 SmallVector<int64_t> resultShape = inferPackedShape(
5644 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5645 return RankedTensorType::get(resultShape, sourceType.getElementType());
5646}
5647
5648MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5649 ArrayRef<int64_t> innerTileSizes,
5650 ArrayRef<int64_t> innerDimsPos,
5651 ArrayRef<int64_t> outerDimsPerm) {
5652 SmallVector<int64_t> resultShape = inferPackedShape(
5653 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5654 return MemRefType::get(resultShape, sourceType.getElementType());
5655}
5656
5657Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5658 ArrayRef<OpFoldResult> innerTileSizes,
5659 ArrayRef<int64_t> innerDimsPos,
5660 ArrayRef<int64_t> outerDimsPerm) {
5661 AffineExpr dim0, dim1;
5662 bindDims(b.getContext(), dim0, dim1);
5663 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5664 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5665 {v1, v2});
5666 };
5667
5668 SmallVector<OpFoldResult> mixedSizes;
5669 for (auto [index, value] : llvm::enumerate(
5670 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5671 if (ShapedType::isDynamic(value))
5672 mixedSizes.push_back(
5673 tensor::DimOp::create(b, loc, source, index).getResult());
5674 else
5675 mixedSizes.push_back(b.getIndexAttr(value));
5676 }
5677 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5678 int64_t dimPos = std::get<0>(it);
5679 OpFoldResult tileSize = std::get<1>(it);
5680 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5681 }
5682 if (!outerDimsPerm.empty())
5683 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5684
5685 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5686 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5687 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5688}
5689
5690PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5691 ArrayRef<int64_t> innerPermutation,
5692 ArrayRef<int64_t> outerPermutation) {
5693 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5694 *this, innerPermutation, outerPermutation);
5695 Value transposedDest =
5696 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5697 metadata.innerDimsPos, metadata.outerDimsPerm);
5698 return PackOp::create(b, loc, getSource(), transposedDest,
5699 metadata.innerDimsPos, metadata.innerTiles,
5700 getPaddingValue(), metadata.outerDimsPerm);
5701}
5702
// Shared memory-effect computation for pack/unpack: tensor semantics has no
// memory effects; with memrefs, the source is read and the dest is both read
// and written.
// NOTE(review): extraction dropped the start of this declaration and the
// trailing resource argument of each emplace_back (presumably
// `SideEffects::DefaultResource::get());`); restore from upstream.
template <typename OpTy>
                     &effects) {
  // No memory effects for pure tensor semantics
  if (op.hasPureTensorSemantics())
    return;

  for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
      continue;

    if (&opOperand == &op.getSourceMutable()) {
      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
    } else if (&opOperand == &op.getDestMutable()) {
      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
      effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
    }
  }
}
5729
/// Delegate to the shared pack/unpack memory-effect computation.
// NOTE(review): extraction dropped the parameter line (presumably
// `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`).
void PackOp::getEffects(
    &effects) {
  getPackUnPackEffectsImpl(*this, effects);
}
5735
/// Delegate to the shared pack/unpack memory-effect computation.
// NOTE(review): extraction dropped the parameter line (presumably
// `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`).
void UnPackOp::getEffects(
    &effects) {
  getPackUnPackEffectsImpl(*this, effects);
}
5741
/// Returns true if the tiles and the tiled dims are constant.
// NOTE(review): extraction dropped the declaration line (presumably
// `static bool areTilesAndTiledDimsAllConstant(OpTy op) {`).
template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // The packed side carries the tiled dims: dest for pack, source for unpack.
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}
5759
/// Compute whether this pack may be speculatively executed.
// NOTE(review): extraction dropped several lines from this function (the
// return statements after each guard and the final constant-tiles check /
// returns); restore from the upstream file before compiling.
Speculation::Speculatability PackOp::getSpeculatability() {
  if (!hasPureTensorSemantics())
  if (getPaddingValue())

  // The verifier already rejects operations where we can statically prove
  // that the tile sizes do not divide the dimension evenly; thus it suffices
  // to check that the tiles and the tiled inner dimensions are constant.

}
5774
5775// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5776// dimensions for pack and unpack.
5777static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5778 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5779 return false;
5780 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5781 return true;
5782 // Outer dims permutation is optional.
5783 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5784 // identity permutation.
5785 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5786 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5787}
5788
5789// Return true if pack and unpack have the same tiles.
5790// Same SSA values or same integer constants.
5791static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5792 auto packTiles = packOp.getMixedTiles();
5793 auto unPackTiles = unPackOp.getMixedTiles();
5794 if (packTiles.size() != unPackTiles.size())
5795 return false;
5796 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5797 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5798 return false;
5799 }
5800 return true;
5801}
5802
5803/// Returns true if the pack op does not need a padding value.
5804static bool paddingIsNotNeeded(PackOp op) {
5805 auto srcType = op.getSourceType();
5806 if (llvm::any_of(op.getInnerDimsPos(),
5807 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5808 return false;
5809 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5810 return false;
5811 return !PackOp::requirePaddingValue(
5812 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5813 op.getOuterDimsPerm(), op.getMixedTiles());
5814}
5815
5816/// Returns true if the `srcShape` or `destShape` is different from the one in
5817/// `packOp` and populates each with the inferred static shape.
5818static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5819 SmallVectorImpl<int64_t> &destShape) {
5820 bool changeNeeded = false;
5821 srcShape.assign(packOp.getSourceType().getShape().begin(),
5822 packOp.getSourceType().getShape().end());
5823 destShape.assign(packOp.getDestType().getShape().begin(),
5824 packOp.getDestType().getShape().end());
5825 llvm::SmallSetVector<int64_t, 4> innerDims;
5826 innerDims.insert_range(packOp.getInnerDimsPos());
5827 SmallVector<int64_t> inverseOuterDimsPerm;
5828 if (!packOp.getOuterDimsPerm().empty())
5829 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5830 int srcRank = packOp.getSourceRank();
5831 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5832 if (innerDims.contains(i))
5833 continue;
5834 int64_t srcPos = i;
5835 int64_t destPos = i;
5836 if (!inverseOuterDimsPerm.empty())
5837 destPos = inverseOuterDimsPerm[srcPos];
5838 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5839 ShapedType::isDynamic(destShape[destPos])) {
5840 continue;
5841 }
5842 int64_t size = srcShape[srcPos];
5843 if (ShapedType::isDynamic(size))
5844 size = destShape[destPos];
5845 srcShape[srcPos] = size;
5846 destShape[destPos] = size;
5847 changeNeeded = true;
5848 }
5849 return changeNeeded;
5850}
5851
/// Canonicalizations for pack (tensor semantics only): fold pack(unpack(x))
/// to x, drop an unneeded padding value, and tighten operand/result types via
/// static shape inference (with tensor.cast ops preserving user-visible
/// types).
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // Fold a pack(unpack(x)) to x, provided the pack exactly undoes the unpack.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() == packOp.getDestType() &&
        !packOp.getPaddingValue() &&
        hasSameInnerOuterAttribute(packOp, unPackOp) &&
        haveSameTiles(packOp, unPackOp)) {
      rewriter.replaceOp(packOp, unPackOp.getSource());
      return success();
    }
  }

  // Fold optional PaddingValue operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    ShapedType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast back to the original result type so existing users keep
    // seeing the type they were built against.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
                                           packOp.getResult());
      rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
    }
    return success();
  }

  return failure();
}
5911
5912template <typename PackOrUnpackOp>
5913static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
5914 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5915 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5916 "Function meant for pack/unpack");
5917 // This is a pad if packing only adds ones and we don't transpose dimensions.
5918
5919 // Check that we are not transposing any dimensions.
5920 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5921 int64_t numPackedDims = innerDimsPos.size();
5922 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5923 if (orderedDims != innerDimsPos) {
5924 // Dimensions don't happen in order.
5925 return false;
5926 }
5927
5928 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5929 int64_t packedRank = packedTensorType.getRank();
5930 // At this point we know that we are taking numPackedDims outer
5931 // dimensions and pushing them all the way as the inner most dimensions.
5932 // What's left on the outer most dimensions is, in this order:
5933 // - the factor of the packed dimensions, then
5934 // - the untouched dimensions
5935 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5936 // if all the dimensions that bubble outerward are ones.
5937 // Therefore check that all the dimensions but the numPackedDims inner most
5938 // ones are ones.
5939 return llvm::all_of(
5940 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5941 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5942}
5943
5944bool PackOp::isLikePad() {
5945 auto packedTensorType =
5946 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5947 return isLikePadUnPad(*this, packedTensorType);
5948}
5949
5950::mlir::LogicalResult
5951PackOp::fold(FoldAdaptor adaptor,
5953 if (!hasPureTensorSemantics())
5954 return failure();
5955 std::optional<Attribute> paddingValue;
5956 if (auto pad = adaptor.getPaddingValue())
5957 paddingValue = pad;
5958 if (OpFoldResult reshapedSource = reshapeConstantSource(
5959 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5960 cast<TensorType>(getDestType()), paddingValue)) {
5961 results.push_back(reshapedSource);
5962 return success();
5963 }
5964 return failure();
5965}
5966
/// Folds a tensor.cast op into a consuming PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
/// ```
// NOTE(review): the struct declaration line (e.g. `struct FoldTensorCastPackOp
// : public OpRewritePattern<PackOp> { using ...; }`) and several interior
// lines were lost in extraction — restore them from the original source before
// building.

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Support Memref PackOp. Temporarily return failure.
    if (!op.hasPureTensorSemantics())
      return failure();

    // NOTE(review): the precondition check guarding this early exit is missing
    // here (extraction artifact).
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
    // NOTE(review): the initializer of `newOperands` is missing here
    // (extraction artifact).

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes =
        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    PackOp newOp =
        PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
                       op.getInnerDimsPos(), newMixedTileSizes,
                       op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op. Cast the new result back to the old result type when the
    // clone became more static, so existing users are unaffected.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    // NOTE(review): the `Value replacement =` line is missing here (extraction
    // artifact).
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};
6025
6026//===----------------------------------------------------------------------===//
6027// UnPackOp
6028//===----------------------------------------------------------------------===//
6029
6030void UnPackOp::getAsmResultNames(
6031 function_ref<void(Value, StringRef)> setNameFn) {
6032 if (!getResults().empty())
6033 setNameFn(getResult(), "unpack");
6034}
6035
6036// Custom parser for UnPackOp that handles the memref/tensor case distinction
6037ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6038 OpAsmParser::UnresolvedOperand source, dest;
6040 SmallVector<int64_t> staticTiles;
6041 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6042 Type sourceType, destType, resultType;
6043
6044 if (parser.parseOperand(source))
6045 return failure();
6046
6047 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6048 if (parser.parseEqual())
6049 return failure();
6050
6051 SmallVector<int64_t> outerDimsPermVec;
6053 int64_t value;
6054 if (parser.parseInteger(value))
6055 return failure();
6056 outerDimsPermVec.push_back(value);
6057 return success();
6058 }))
6059 return failure();
6060 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6061 }
6062
6063 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6064 return failure();
6065
6066 SmallVector<int64_t> innerDimsPosVec;
6068 int64_t value;
6069 if (parser.parseInteger(value))
6070 return failure();
6071 innerDimsPosVec.push_back(value);
6072 return success();
6073 }))
6074 return failure();
6075 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6076
6077 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6078 return failure();
6079
6080 DenseI64ArrayAttr staticTilesAttr;
6081 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6082 return failure();
6083 for (auto val : staticTilesAttr.asArrayRef())
6084 staticTiles.push_back(val);
6085
6086 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6087 return failure();
6088
6089 if (parser.parseOptionalAttrDict(result.attributes))
6090 return failure();
6091
6092 if (parser.parseColon() || parser.parseType(sourceType))
6093 return failure();
6094
6095 bool hasArrow = succeeded(parser.parseOptionalArrow());
6096 if (hasArrow) {
6097 if (parser.parseType(destType))
6098 return failure();
6099 }
6100
6101 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6102 if (!hasArrow) {
6103 return parser.emitError(parser.getCurrentLocation(),
6104 "pack/unpack requires '->' and destination type");
6105 }
6106
6107 if (!isMemRef)
6108 resultType = destType;
6109
6110 if (parser.resolveOperand(source, sourceType, result.operands) ||
6111 parser.resolveOperand(dest, destType, result.operands))
6112 return failure();
6113
6114 if (!dynamicTiles.empty() &&
6115 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6116 result.operands))
6117 return failure();
6118
6119 result.addAttribute("static_inner_tiles",
6120 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6121 result.addAttribute("inner_dims_pos", innerDimsPos);
6122 if (outerDimsPerm)
6123 result.addAttribute("outer_dims_perm", outerDimsPerm);
6124
6125 SmallVector<int32_t> segmentSizes = {
6126 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6127 result.addAttribute("operandSegmentSizes",
6128 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6129
6130 if (!isMemRef)
6131 result.addTypes(resultType);
6132
6133 return success();
6134}
6135
6136void UnPackOp::print(OpAsmPrinter &p) {
6137 p << " " << getSource();
6138
6139 if (!getOuterDimsPerm().empty()) {
6140 p << " outer_dims_perm = [";
6141 llvm::interleaveComma(getOuterDimsPerm(), p);
6142 p << "]";
6143 }
6144
6145 p << " inner_dims_pos = [";
6146 llvm::interleaveComma(getInnerDimsPos(), p);
6147 p << "]";
6148
6149 p << " inner_tiles = ";
6150 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6151
6152 p << " into " << getDest();
6153
6154 p.printOptionalAttrDict((*this)->getAttrs(),
6155 {"static_inner_tiles", "inner_dims_pos",
6156 "outer_dims_perm", "operandSegmentSizes"});
6157
6158 p << " : " << getSource().getType();
6159 p << " -> " << getDest().getType();
6160}
6161
/// Reifies the result shape via the shared pack/unpack implementation.
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
6167
/// Returns the tiled-dimension -> tile-size mapping (shared impl).
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(*this);
}
6171
/// Returns the tile sizes as a mix of attributes and SSA values (shared impl).
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(*this);
}
6175
/// Returns the static tile sizes (shared impl).
SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(*this);
}
6179
6180ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6181 ShapedType destType = getDestType();
6182 int64_t destRank = destType.getRank();
6183 return getSourceType().getShape().take_front(destRank);
6184}
6185
6186SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6187 auto innerDimsPos = getInnerDimsPos();
6188 SmallVector<int64_t> outerDims(getAllOuterDims());
6190
6191 // Recover the original order of the outer dims.
6192 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6193 invertPermutationVector(outerDimPermInv);
6194 if (!outerDimPermInv.empty())
6195 applyPermutationToVector(outerDims, outerDimPermInv);
6196
6197 // Collect the outer dims corresponding to the tilled inner dims.
6198 for (auto index : innerDimsPos)
6199 res.push_back(outerDims[index]);
6200
6201 return res;
6202}
6203
/// Delegates to the verifier shared between PackOp and UnPackOp.
LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(*this);
}
6207
// NOTE(review): this function's body is truncated (extraction artifact) — the
// return statements after both checks are missing. Restore from the original
// source before building.
Speculation::Speculatability UnPackOp::getSpeculatability() {
  if (!hasPureTensorSemantics())
    // See PackOp::getSpeculatability.
}
6217
6218void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6219 Value dest, ArrayRef<int64_t> innerDimsPos,
6220 ArrayRef<OpFoldResult> innerTiles,
6221 ArrayRef<int64_t> outerDimsPerm) {
6222 assert(innerDimsPos.size() == innerTiles.size() &&
6223 "number of tile sizes specified must match the specified number of "
6224 "original dimensions to be tiled");
6225 SmallVector<int64_t> staticTileSizes;
6226 SmallVector<Value> dynamicTileSizes;
6227 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
6228 build(builder, state, dest.getType(), source, dest,
6229 outerDimsPerm.empty() ? nullptr
6230 : builder.getDenseI64ArrayAttr(outerDimsPerm),
6231 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
6232 builder.getDenseI64ArrayAttr(staticTileSizes));
6233}
6234
6235Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6236 Value source,
6237 ArrayRef<OpFoldResult> innerTileSizes,
6238 ArrayRef<int64_t> innerDimsPos,
6239 ArrayRef<int64_t> outerDimsPerm) {
6240 AffineExpr sym0, sym1;
6241 bindSymbols(b.getContext(), sym0, sym1);
6242 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6243 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6244 };
6245
6246 SmallVector<OpFoldResult> mixedSizes;
6247 auto srcType = llvm::cast<RankedTensorType>(source.getType());
6248 for (auto i :
6249 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6250 if (srcType.isDynamicDim(i))
6251 mixedSizes.push_back(
6252 tensor::DimOp::create(b, loc, source, i).getResult());
6253 else
6254 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6255 }
6256 if (!outerDimsPerm.empty()) {
6258 mixedSizes, invertPermutationVector(outerDimsPerm));
6259 }
6260
6261 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6262 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6263
6264 auto elemType = srcType.getElementType();
6265 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6266}
6267
/// Builds an UnPackOp like this one but with its metadata (inner/outer dim
/// positions and tiles) permuted by the given permutations, reading from the
/// already-transposed `transposedSource`.
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return UnPackOp::create(b, loc, transposedSource, getDest(),
                          metadata.innerDimsPos, metadata.innerTiles,
                          metadata.outerDimsPerm);
}
6278
6279/// Returns true if the `srcShape` or `destShape` is different from the one in
6280/// `op` and populates each with the inferred static shape.
6281static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
6282 SmallVectorImpl<int64_t> &destShape) {
6283 bool changeNeeded = false;
6284 srcShape.assign(op.getSourceType().getShape().begin(),
6285 op.getSourceType().getShape().end());
6286 destShape.assign(op.getDestType().getShape().begin(),
6287 op.getDestType().getShape().end());
6288 llvm::SmallSetVector<int64_t, 4> innerDims;
6289 innerDims.insert_range(op.getInnerDimsPos());
6290 SmallVector<int64_t> inverseOuterDimsPerm;
6291 if (!op.getOuterDimsPerm().empty())
6292 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
6293 int destRank = op.getDestRank();
6294 for (auto i : llvm::seq<int64_t>(0, destRank)) {
6295 if (innerDims.contains(i))
6296 continue;
6297 int64_t srcPos = i;
6298 int64_t destPos = i;
6299 if (!inverseOuterDimsPerm.empty())
6300 srcPos = inverseOuterDimsPerm[destPos];
6301 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6302 ShapedType::isDynamic(destShape[destPos])) {
6303 continue;
6304 }
6305 int64_t size = srcShape[srcPos];
6306 if (ShapedType::isDynamic(size))
6307 size = destShape[destPos];
6308 srcShape[srcPos] = size;
6309 destShape[destPos] = size;
6310 changeNeeded = true;
6311 }
6312 return changeNeeded;
6313}
6314
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // TODO: Support Memref UnPackOp. Temporarily return failure.
  if (!unPackOp.hasPureTensorSemantics())
    return failure();

  /// unpack(pack(x)) -> x, valid only when the round-trip is an exact
  /// inverse: matching types, no padding, and identical tiling metadata.
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x): the unpack overwrites its
  /// dest, so it can read the producer's init directly.
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y)),
  /// only when the unpack has a single user and the slice is foldable
  /// (see canFoldSliceOp).
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      auto newDest = tensor::ExtractSliceOp::create(
          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    // Cast operands to their more-static inferred types when they changed.
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = tensor::CastOp::create(rewriter, loc, newSrcType,
                                      unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = tensor::CastOp::create(rewriter, loc, newDestType,
                                    unPackOp.getDest());
    }
    UnPackOp newOp = UnPackOp::create(
        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
    // Cast back to the original result type so users are unaffected.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp.getResult());
    return success();
  }

  return failure();
}
6387
/// Returns true when `sliceOp` can be folded into this unpack's destination:
/// the slice must not rank-reduce, must start at offset 0 with unit strides,
/// may only shrink tiled dimensions by less than one tile (i.e. remove
/// padding), and must leave untiled dimensions untouched.
bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
  // Rank-reduced folding is not supported.
  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
    return false;
  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
    return false;
  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
  SmallVector<int64_t> outerShapeWithoutTranspose =
  // NOTE(review): the initializer of `outerShapeWithoutTranspose` is missing
  // here (extraction artifact) — restore from the original source.
  SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
  for (auto [pos, tileSize] :
       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
    areOuterDimsTiled[pos] = true;
    // All quantities must be static to reason about padding.
    if (unpackedTypeAfterFold.isDynamicDim(pos))
      return false;
    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
      return false;
    if (ShapedType::isDynamic(tileSize))
      return false;
    // The slice may only strip less than one tile's worth of padding.
    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                          unpackedTypeAfterFold.getDimSize(pos);
    if (paddingSize >= tileSize)
      return false;
  }
  // extract_slice must not affect dimensions that are not being unpacked.
  for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
    if (areOuterDimsTiled[pos])
      continue;
    int64_t dim = outerShapeWithoutTranspose[pos];
    if (ShapedType::isDynamic(dim))
      return false;
    if (dim != unpackedTypeAfterFold.getDimSize(pos))
      return false;
  }
  return true;
}
6425
6426bool UnPackOp::isLikeUnPad() {
6427 ShapedType packedTensorType = getSourceType();
6428 return isLikePadUnPad(*this, packedTensorType);
6429}
6430
6431::mlir::LogicalResult
6432UnPackOp::fold(FoldAdaptor adaptor,
6433 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6434 // TODO: Support Memref UnPackOp. Temporarily return failure.
6435 if (!hasPureTensorSemantics())
6436 return failure();
6437
6438 if (OpFoldResult reshapedSource = reshapeConstantSource(
6439 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6440 cast<TensorType>(getResult().getType()))) {
6441 results.push_back(reshapedSource);
6442 return success();
6443 }
6444 return failure();
6445}
6446
/// Folds a tensor.cast op into a consuming UnPackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
/// ```
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: Support Memref UnPackOp. Temporarily return failure.
    if (!op.hasPureTensorSemantics())
      return failure();

    // NOTE(review): the precondition check guarding this early exit is missing
    // here (extraction artifact).
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
    // NOTE(review): the initializer of `newOperands` is missing here
    // (extraction artifact).
    Value sourceTensor = newOperands[0];

    // Get the updated mixed-tile-sizes attribute.
    // NOTE(review): the declaration of `newMixedTileSizes` is missing here
    // (extraction artifact).
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
                                      newOperands[1], op.getInnerDimsPos(),
                                      newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op. Cast the new result back to the old result type when the
    // clone became more static, so existing users are unaffected.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    // NOTE(review): the `Value replacement =` line is missing here (extraction
    // artifact).
        (newResult.getType() != oldResult.getType())
            ? tensor::CastOp::create(rewriter, op->getLoc(),
                                     oldResult.getType(), newResult)
            : newResult;

    rewriter.replaceOp(op, {replacement});

    return success();
  }
};
6505
6506//===----------------------------------------------------------------------===//
6507// BatchReduceMatmulOp
6508//===----------------------------------------------------------------------===//
6509SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6511 utils::IteratorType::reduction, utils::IteratorType::parallel,
6512 utils::IteratorType::parallel, utils::IteratorType::reduction};
6513}
6514
6516BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6517 AffineExpr d0, d1, d2, d3;
6518 SmallVector<AffineMap> indexingMaps;
6519 bindDims(context, d0, d1, d2, d3);
6520 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6521 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6522 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6523 return indexingMaps;
6524}
6525
6526bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6527 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6528 if (!maps)
6529 return false;
6530 if (maps.size() != 3)
6531 return false;
6532 auto positions = getAffineResultPositions(maps);
6533 if (failed(positions))
6534 return false;
6535 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6536 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6537 (*positions)[2] == SmallVector<int64_t>{1, 2};
6538}
/// The region takes two input arguments and one output argument.
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6540
/// Returns the mangled library call name derived from this operation.
std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
6544
6545/// Check if the op has broadcast and/or transpose semantic. Returns true if
6546/// the user defined indexing maps are not equal to default map.
6547bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6548 SmallVector<AffineMap, 3> defaultMaps =
6549 getDefaultIndexingMaps(this->getContext());
6550 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6551 return defaultMaps != explicitMaps;
6552}
6553
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
/// broadcast map must include K dimension.
/// TODO: Strict inclusion of K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to have K dimension for one input matrix map and infer the K
/// dimension for other input matrix map from the one already having K
/// dimension.
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  // Positional meaning of the loop dimensions: batch (0), m (1), n (2), k (3).
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    // Rank-1 broadcast: the single result must be the reduction dim K.
    AffineExpr expr = bcastMap.getResult(0);
    isValid = expr.isFunctionOfDim(kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(0);
    AffineExpr expr1 = bcastMap.getResult(1);
    // Rank-2 broadcast: LHS accepts (batch|m, k); RHS accepts (batch, k)
    // or (k, n).
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
                  expr0.isFunctionOfDim(mPos)) &&
                 expr1.isFunctionOfDim(kPos))
              : ((expr0.isFunctionOfDim(batchPos) &&
                  expr1.isFunctionOfDim(kPos)) ||
                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
  }
  return isValid;
}
6583
6584void BatchReduceMatmulOp::regionBuilder(
6587 if (emitError && block.getNumArguments() != 3) {
6588 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6589 << block.getNumArguments();
6590 return;
6591 }
6592 assert(block.getNumArguments() == 3 &&
6593 "BatchReduceMatmulOp regionBuilder expects 3 args");
6594 RegionBuilderHelper helper(b, block);
6595 SmallVector<Value> yields;
6596
6597 auto toType = block.getArgument(2).getType();
6598 Value castValA =
6599 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6600 Value castValB =
6601 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6602 Value mulVal =
6603 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB, emitError);
6604 if (!castValA || !castValB || !mulVal)
6605 return;
6606 Value addVal =
6607 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6608 if (!addVal)
6609 return;
6610 yields.push_back(addVal);
6611 helper.yieldOutputs(yields);
6612}
6613
6614ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6616 SmallVector<Attribute, 3> indexingMapsAttr;
6617 Attribute mapAttr;
6618 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6619 if (parser.parseEqual())
6620 return failure();
6621 if (parser.parseLSquare())
6622 return failure();
6623
6624 do {
6625 if (parser.parseAttribute(mapAttr))
6626 return failure();
6627 if (!isa<AffineMapAttr>(mapAttr)) {
6628 return parser.emitError(parser.getCurrentLocation(),
6629 "expected affine map attribute");
6630 }
6631 indexingMapsAttr.push_back(mapAttr);
6632
6633 if (parser.parseOptionalComma())
6634 break;
6635 } while (true);
6636
6637 if (parser.parseRSquare())
6638 return failure();
6639 }
6640 // Initialize indexingMaps, if not supplied explicitly.
6641 if (indexingMapsAttr.empty()) {
6642 indexingMapsAttr = llvm::map_to_vector(
6643 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6644 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6645 }
6646 result.addAttribute("indexing_maps",
6647 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6648 return ::parseNamedStructuredOp(parser, result,
6649 BatchReduceMatmulOp::getNumRegionArgs(),
6650 BatchReduceMatmulOp::getRegionBuilder());
6651}
6652
6653void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6654 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6655 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6656 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6657
6658 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6659 p << " indexing_maps = [";
6660 llvm::interleaveComma(getIndexingMaps(), p,
6661 [&](Attribute attr) { p.printAttribute(attr); });
6662 p << "]";
6663 }
6664
6665 SmallVector<StringRef, 3> elidedAttrs = {
6666 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6667 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6668 elidedAttrs);
6669}
6670
/// Verify the user defined indexing maps.
LogicalResult BatchReduceMatmulOp::verify() {
  // Verification of pure batch_reduce_matmul is handled by
  // verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    // NOTE(review): the per-operand verification call guarding this early
    // exit is missing here (extraction artifact) — restore from the original
    // source.
      return failure();
  }
  return success();
}
6684LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6686 return memref::foldMemRefCast(*this);
6687}
6688void BatchReduceMatmulOp::getEffects(
6690 &effects) {
6691 if (hasPureTensorSemantics())
6692 return;
6693 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6694}
6695
/// Speculatability is decided by the generic LinalgOp implementation.
Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
6699
6700} // namespace linalg
6701} // namespace mlir
6702
6703//===----------------------------------------------------------------------===//
6704// LinalgDialect
6705//===----------------------------------------------------------------------===//
6706
/// Registers dialect-wide canonicalization patterns: dead linalg op erasure,
/// tensor.cast folding into consumers (including pack/unpack), and static
/// shape inference on operands.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
              FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
}
6712
/// Materializes a constant for this dialect by delegating to arith.constant.
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:59
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:563
result_iterator result_begin()
Definition Operation.h:442
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
unsigned getNumOperands()
Definition Operation.h:375
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:426
result_iterator result_end()
Definition Operation.h:443
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:97
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:239
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
Definition Utils.cpp:1306
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override