//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}
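
// Illustrative usage (assumes a value `t : tensor<4x?xf32>`): the static
// dimension folds to an attribute while the dynamic one materializes a dim op.
//
//   OpFoldResult d0 = getDimValue(builder, loc, t, /*dim=*/0); // IndexAttr 4
//   OpFoldResult d1 = getDimValue(builder, loc, t, /*dim=*/1); // tensor.dim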

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
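
// Illustrative result (tensor case, assuming a tensor<8x8xf32> source):
//
//   %slice = tensor.extract_slice %source[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//
// A memref source produces the analogous memref.subview; any other type
// yields nullptr.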

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}
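
// Illustrative call (hypothetical values): with `resultTensorTypes` left
// unset, tensor-typed outputs contribute the derived results and memref
// outputs contribute none.
//
//   buildStructuredOp(b, state, /*resultTensorTypes=*/std::nullopt,
//                     /*inputs=*/{lhs, rhs}, /*outputs=*/{initTensor},
//                     /*attributes=*/{}, regionBuilder);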

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
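
// Illustrative textual form handled by the parser above:
//
//   linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//                 outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
//
// Both the `ins` and `outs` clauses are optional; `operandSegmentSizes`
// records how many operands fall in each group.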

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. They are helpers that build the unary, binary, and type conversion
// functions defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc
// for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr.
// This can be extended later if it becomes possible to fail construction of
// the region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of
// integer and floating point types, but they will not internally cast within
// bit widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or
// need to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL. If `emitError` is provided,
  // an error is emitted when the operation is not supported and nullptr is
  // returned; otherwise an assertion is raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned max not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned max not on uint");
      }
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned min not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned min not on uint");
      }
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }
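
  // Illustrative dispatch of the helper above (types shown are the element
  // types of the block arguments):
  //   add(f32, f32) -> arith.addf    add(i32, i32) -> arith.addi
  //   add(i1, i1)   -> arith.ori     mul(i1, i1)   -> arith.andi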

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }
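
  // Illustrative example: cast_signed from i32 to f32 emits arith.sitofp,
  // whereas cast_unsigned emits arith.uitofp (see convertScalarToDtype below).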

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
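
// Illustrative example: a buffer self-copy such as
//
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
//
// is simply erased; the tensor-typed variant is replaced by its input value.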

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
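
// Illustrative example:
//
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<6x4xf32>)
//       -> tensor<6x4xf32>
//   %r = tensor.collapse_shape %f [[0, 1]]
//       : tensor<6x4xf32> into tensor<24xf32>
//
// becomes a collapse_shape of %init followed by a linalg.fill that directly
// produces tensor<24xf32>.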

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
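
// Illustrative example: padding a filled tensor with the fill constant
//
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %p = tensor.pad %f low[2] high[2] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<4xf32> to tensor<8xf32>
//
// folds to a single linalg.fill on a tensor.empty of the padded shape.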

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
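
// Illustrative example: inserting a padded-and-filled slice into a
// destination produced by a fill of the same constant
//
//   tensor.insert_slice(tensor.pad(%src, %cst), linalg.fill(%cst, %dest))
//
// is rewritten to insert %src directly, shifting the offsets by the low
// padding amounts.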

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
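
// Illustrative example:
//
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
//       -> tensor<4x4xf32>
//   %e = tensor.extract %f[%i, %j] : tensor<4x4xf32>
//
// folds %e directly to %cst.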

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
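
// Illustrative rewrites (fill value %v):
//   copy(fill(%v, %t), %out)   -> fill(%v, %out)
//   copy(%in, fill(%v, %dest)) -> copy(%in, %dest)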

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
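
// Illustrative example:
//
//   %a = linalg.fill ins(%v : f32) outs(%i0 : tensor<2x4xf32>) -> tensor<2x4xf32>
//   %b = linalg.fill ins(%v : f32) outs(%i1 : tensor<3x4xf32>) -> tensor<3x4xf32>
//   %c = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// becomes a tensor.concat of %i0 and %i1 followed by a single linalg.fill.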

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must be checked into a dictAttr. The
  // name is unimportant as we will overwrite result.attributes. The core
  // linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
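
// Illustrative textual form round-tripped by print/parse above:
//
//   %0 = linalg.generic {indexing_maps = [#map, #map],
//                        iterator_types = ["parallel", "parallel"]}
//       ins(%arg0 : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     %1 = arith.addf %in, %in : f32
//     linalg.yield %1 : f32
//   } -> tensor<4x8xf32>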

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnedArg.getType()) !=
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
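
// Illustrative example: an identity mapping
//
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<8xf32>
//
// is replaced by %arg0 (with a tensor.cast or sparse conversion if the types
// differ).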

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{init}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false, bool mapInit = true) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    if (mapInit)
      payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end() - int(!mapInit)};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  YieldOp::create(b, result.location, payloadOp->getResults());
}
1547
1548ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1549 std::optional<OperationName> payloadOpName;
1550 NamedAttrList payloadOpAttrs;
1551 if (succeeded(parser.parseOptionalLBrace())) {
1552 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1553 if (failed(operationName))
1554 return failure();
1555 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1556 return failure();
1557 payloadOpName = operationName.value();
1558 if (parser.parseRBrace())
1559 return failure();
1560 }
1561
1562 if (parseDstStyleOp(parser, result))
1563 return failure();
1564
1565 if (payloadOpName.has_value()) {
1566 if (!result.operands.empty())
1567 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1568 payloadOpAttrs, ArrayRef(result.operands),
1569 /*initFirst=*/false, /*mapInit=*/false);
1570 else
1571 result.addRegion();
1572 } else {
1573 SmallVector<OpAsmParser::Argument> regionArgs;
1574 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1575 /*allowType=*/true, /*allowAttrs=*/true)) {
1576 return failure();
1577 }
1578 Region *body = result.addRegion();
1579 if (parser.parseRegion(*body, regionArgs))
1580 return failure();
1581 }
1582 return success();
1583}
1584
1585static bool canUseShortForm(Block *body, bool initFirst = false,
1586 bool mapInit = true) {
1587 // `initFirst == true` implies that the init arg is mapped.
1588 if (initFirst && !mapInit)
1589 return false;
1590 // Check if the body can be printed in short form. The following 4 conditions
1591 // must be satisfied:
1592
1593 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1594 if (body->getOperations().size() != 2)
1595 return false;
1596 Operation &payload = body->getOperations().front();
1597
1598 // 2) The payload op must have the same number of operands as the number of
1599 // mapped block arguments (all arguments, minus one when `mapInit` is false).
1600 if (payload.getNumOperands() == 0 ||
1601 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1602 return false;
1603
1604 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1605 // argument must be the first operand of the payload op; otherwise, the
1606 // operands must match the block arguments in order.
1607 if (initFirst) {
1608 // check init
1609 if (payload.getOperands().back() != body->getArgument(0))
1610 return false;
1611 // check rest
1612 for (const auto &[operand, bbArg] :
1613 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1614 if (bbArg != operand)
1615 return false;
1616 }
1617 } else {
1618 for (const auto &[operand, bbArg] :
1619 llvm::zip(payload.getOperands(),
1620 body->getArguments().drop_back(int(!mapInit)))) {
1621 if (bbArg != operand)
1622 return false;
1623 }
1624 }
1625
1626 // 4) The `yield` operand must be the result of the payload op.
1627 auto yieldOp = cast<YieldOp>(body->getTerminator());
1628 return yieldOp.getNumOperands() == 1 &&
1629 yieldOp.getOperand(0).getDefiningOp() &&
1630 yieldOp.getOperand(0).getDefiningOp() == &payload;
1631}
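// For illustration (operand names and shapes below are hypothetical), a body
// consisting of a single payload op whose operands are exactly the input
// block arguments satisfies all four conditions and prints in short form:
//
//   %mapped = linalg.map { arith.addf }
//               ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
//               outs(%init : tensor<8xf32>)
//
// A body with more than one payload op fails condition 1 and is printed with
// an explicit region instead:
//
//   %mapped = linalg.map
//               ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
//               outs(%init : tensor<8xf32>)
//               (%in0: f32, %in1: f32, %out: f32) {
//                 %0 = arith.mulf %in0, %in1 : f32
//                 %1 = arith.addf %0, %0 : f32
//                 linalg.yield %1 : f32
//               }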
1632
1633static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1634 SmallVector<StringRef> elidedAttrs;
1635 std::string attrToElide;
1636 p << " { " << payloadOp->getName().getStringRef();
1637 for (const auto &attr : payloadOp->getAttrs()) {
1638 auto fastAttr =
1639 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1640 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1641 attrToElide = attr.getName().str();
1642 elidedAttrs.push_back(attrToElide);
1643 break;
1644 }
1645 }
1646 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1647 p << " }";
1648}
1649
1650void MapOp::print(OpAsmPrinter &p) {
1651 Block *mapper = getBody();
1652 bool useShortForm =
1653 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit=*/false);
1654 if (useShortForm) {
1655 printShortForm(p, &mapper->getOperations().front());
1656 }
1657
1658 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1659 p.printOptionalAttrDict((*this)->getAttrs());
1660
1661 if (!useShortForm) {
1662 // Print region if the payload op was not detected.
1663 p.increaseIndent();
1664 p.printNewline();
1665 p << "(";
1666 llvm::interleaveComma(mapper->getArguments(), p,
1667 [&](auto arg) { p.printRegionArgument(arg); });
1668 p << ") ";
1669
1670 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1671 p.decreaseIndent();
1672 }
1673}
1674
1675LogicalResult MapOp::verify() {
1676 auto *bodyBlock = getBody();
1677 auto blockArgs = bodyBlock->getArguments();
1678
1679 // Check that the number of `inputs` + `init` matches the arity of the
1680 // `mapper` region.
1681 if (getInputs().size() + 1 != blockArgs.size())
1682 return emitOpError() << "expects number of operands to match the arity of "
1683 "mapper, but got: "
1684 << getInputs().size() + 1 << " and "
1685 << blockArgs.size();
1686
1687 // The parameters of the mapper must all match the element types of the inputs.
1688 for (const auto &[bbArgType, inputArg] :
1689 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1690 auto inputElemType =
1691 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1692 if (bbArgType != inputElemType) {
1693 return emitOpError() << "expected element type of input " << inputElemType
1694 << " to match bbArg type " << bbArgType;
1695 }
1696 }
1697
1698 // The shape of each input must match the shape of the output.
1699 auto outputShape = getInit().getType().getShape();
1700 for (Type inputArgType : TypeRange{getInputs()}) {
1701 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1702 if (inputElemShape != outputShape) {
1703 return emitOpError() << "expected shape of input (" << inputElemShape
1704 << ") to match shape of output (" << outputShape
1705 << ")";
1706 }
1707 }
1708
1709 return success();
1710}
1711
1712SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1713 int64_t rank = getInit().getType().getRank();
1714 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1715}
1716
1717ArrayAttr MapOp::getIndexingMaps() {
1718 Builder builder(getContext());
1719 int64_t rank = getInit().getType().getRank();
1720 int64_t numIndexingMaps = getOperands().size();
1721 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1722 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1723}
1724
1725void MapOp::getEffects(
1726 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1727 &effects) {
1728 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1729}
1730
1731Speculation::Speculatability MapOp::getSpeculatability() {
1732 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1733}
1734
1735//===----------------------------------------------------------------------===//
1736// ReduceOp
1737//===----------------------------------------------------------------------===//
1738
1739void ReduceOp::getAsmBlockArgumentNames(Region &region,
1740 OpAsmSetValueNameFn setNameFn) {
1741 for (Value v : getRegionInputArgs())
1742 setNameFn(v, "in");
1743 for (Value v : getRegionOutputArgs())
1744 setNameFn(v, "init");
1745}
1746
1747void ReduceOp::getAsmResultNames(
1748 function_ref<void(Value, StringRef)> setNameFn) {
1749 if (!getResults().empty())
1750 setNameFn(getResults().front(), "reduced");
1751}
1752
1753void ReduceOp::build(
1754 OpBuilder &builder, OperationState &result, ValueRange inputs,
1755 ValueRange inits, ArrayRef<int64_t> dimensions,
1756 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1757 ArrayRef<NamedAttribute> attributes) {
1758 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1759 result.addAttributes(attributes);
1760
1761 // Add output types for `RankedTensorType` output arguments.
1762 for (Value init : inits) {
1763 Type initType = init.getType();
1764 if (llvm::isa<RankedTensorType>(initType))
1765 result.addTypes(initType);
1766 }
1767
1768 if (bodyBuild)
1769 buildGenericRegion(builder, result.location, *result.regions.front(),
1770 inputs, inits, bodyBuild);
1771}
1772
1773SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1774 int64_t inputRank =
1775 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1776 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1777 utils::IteratorType::parallel);
1778 for (int64_t reductionDim : getDimensions())
1779 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1780 return iteratorTypes;
1781}
1782
1783ArrayAttr ReduceOp::getIndexingMaps() {
1784 int64_t inputRank =
1785 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1786 SmallVector<AffineMap> affineMaps(
1787 getNumDpsInputs(),
1788 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1789 AffineMap resultMap =
1790 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1791 .dropResults(getDimensions());
1792 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1793 affineMaps.push_back(resultMap);
1794 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1795}
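// For example, for a rank-3 input with dimensions = [1], every input operand
// gets the identity map (d0, d1, d2) -> (d0, d1, d2) and every init operand
// gets the reduced map (d0, d1, d2) -> (d0, d2).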
1796
1797void ReduceOp::getEffects(
1798 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1799 &effects) {
1800 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1801}
1802
1803Speculation::Speculatability ReduceOp::getSpeculatability() {
1804 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1805}
1806
1807static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1808 NamedAttrList &attributes,
1809 StringRef attributeName) {
1810 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1811 return failure();
1812
1813 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1814 return success();
1815}
1816
1817ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1818 std::optional<OperationName> payloadOpName;
1819 NamedAttrList payloadOpAttrs;
1820 if (succeeded(parser.parseOptionalLBrace())) {
1821 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1822 if (failed(operationName))
1823 return failure();
1824 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1825 return failure();
1826 payloadOpName = operationName.value();
1827 if (parser.parseRBrace())
1828 return failure();
1829 }
1830
1831 if (parseDstStyleOp(
1832 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1833 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1834 }))
1835 return failure();
1836
1837 if (payloadOpName.has_value()) {
1838 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1839 ArrayRef(result.operands), /*initFirst=*/true);
1840 } else {
1841 SmallVector<OpAsmParser::Argument> regionArgs;
1842 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1843 /*allowType=*/true, /*allowAttrs=*/true)) {
1844 return failure();
1845 }
1846
1847 Region *body = result.addRegion();
1848 if (parser.parseRegion(*body, regionArgs))
1849 return failure();
1850 }
1851
1852 return success();
1853}
1854
1855static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1856 ArrayRef<int64_t> attributeValue) {
1857 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1858}
1859
1860void ReduceOp::print(OpAsmPrinter &p) {
1861 Block *mapper = getBody();
1862 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1863 if (useShortForm) {
1864 printShortForm(p, &mapper->getOperations().front());
1865 }
1866
1867 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1868 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1869 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1870 if (!useShortForm) {
1871 // Print region if the payload op was not detected.
1872 p.increaseIndent();
1873 p.printNewline();
1874 p << "(";
1875 llvm::interleaveComma(mapper->getArguments(), p,
1876 [&](auto arg) { p.printRegionArgument(arg); });
1877 p << ") ";
1878
1879 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1880 p.decreaseIndent();
1881 }
1882}
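// As a sketch (names and shapes hypothetical), a combiner consisting of a
// single `arith.addf` over the init and input arguments prints in short form:
//
//   %reduced = linalg.reduce { arith.addf }
//                ins(%in : tensor<16x32xf32>)
//                outs(%init : tensor<16xf32>) dimensions = [1]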
1883
1884LogicalResult ReduceOp::verify() {
1885 ArrayRef<int64_t> dimensionsRef = getDimensions();
1886
1887 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1888 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1889 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1890 return emitOpError() << "expects all inputs to have the same shapes. "
1891 "Shape at input-index "
1892 << i
1893 << " is not equal to the shape at input-index 0.";
1894 }
1895 }
1896 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1897 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1898 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1899 return emitOpError() << "expects all outputs to have the same shapes. "
1900 "Shape at output-index "
1901 << i
1902 << " is not equal to the shape at output-index 0.";
1903 }
1904 }
1905 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1906 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1907
1908 DenseSet<int64_t> dimensionsToReduce;
1909 for (int64_t dimension : dimensionsRef) {
1910 if (dimension < 0 || dimension >= inputType.getRank()) {
1911 return emitOpError()
1912 << "dimensions for reduction should be in the range [0, "
1913 << inputType.getRank() - 1 << "].";
1914 }
1915 dimensionsToReduce.insert(dimension);
1916 }
1917
1918 auto inputDims = inputType.getShape();
1919 auto initDims = initType.getShape();
1920
1921 // Input dimensions that will be left after the reduction.
1922 SmallVector<int64_t> reducedInputDims;
1923 for (const auto &en : llvm::enumerate(inputDims)) {
1924 if (!dimensionsToReduce.count(en.index()))
1925 reducedInputDims.push_back(en.value());
1926 }
1927
1928 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1929 return emitOpError() << "number of dimensions after reduction "
1930 << reducedInputDims.size()
1931 << " doesn't match the init rank "
1932 << initType.getRank();
1933 }
1934
1935 if (reducedInputDims != initDims)
1936 return emitOpError() << "init dimensions [" << initDims
1937 << "] doesn't match input dimensions after reduction ["
1938 << reducedInputDims << "]";
1939
1940 Block *block = getBody();
1941 if (block->getNumArguments() != this->getNumOperands())
1942 return emitOpError()
1943 << "mismatching number of operands and block arguments";
1944
1945 // Check that the first block arguments match the element type of the inputs.
1946 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1947 Type inputElementType =
1948 llvm::cast<ShapedType>(input.getType()).getElementType();
1949 if (inputElementType != bbArg.getType())
1950 return emitOpError()
1951 << "input element type " << inputElementType
1952 << " does not match corresponding block argument type "
1953 << bbArg.getType();
1954 }
1955
1956 // Check that the last block arguments match the element type of the outputs.
1957 for (auto [output, bbArg] : llvm::zip(
1958 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1959 auto outputElementType =
1960 llvm::cast<ShapedType>(output.getType()).getElementType();
1961 if (outputElementType != bbArg.getType())
1962 return emitOpError()
1963 << "output element type " << outputElementType
1964 << " does not match corresponding block argument type "
1965 << bbArg.getType();
1966 }
1967 return success();
1968}
1969
1970//===----------------------------------------------------------------------===//
1971// TransposeOp
1972//===----------------------------------------------------------------------===//
1973
1974static void buildIdentityRegion(OpBuilder &builder, Location loc,
1975 Region &region, ValueRange inputs,
1976 ValueRange outputs) {
1977 buildGenericRegion(builder, loc, region, inputs, outputs,
1978 [](OpBuilder &b, Location loc, ValueRange args) {
1979 if (!args.empty())
1980 linalg::YieldOp::create(b, loc, args[0]);
1981 });
1982}
1983
1984void TransposeOp::build(::mlir::OpBuilder &builder,
1985 ::mlir::OperationState &result, Value input, Value init,
1986 DenseI64ArrayAttr permutation,
1987 ArrayRef<NamedAttribute> attributes) {
1988 result.addOperands(input);
1989 result.addOperands(init);
1990 result.addAttribute(getPermutationAttrName(result.name), permutation);
1991 result.addAttributes(attributes);
1992
1993 // Add output types for `RankedTensorType` output arguments.
1994 Type initType = init.getType();
1995 if (llvm::isa<RankedTensorType>(initType))
1996 result.addTypes(initType);
1997
1998 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1999 init);
2000}
2001
2002void TransposeOp::build(::mlir::OpBuilder &builder,
2003 ::mlir::OperationState &result, Value input, Value init,
2004 ArrayRef<int64_t> permutation,
2005 ArrayRef<NamedAttribute> attributes) {
2006 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2007 attributes);
2008}
2009
2010ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2011 if (failed(parseDstStyleOp(
2012 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2013 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2014 })))
2015 return failure();
2016
2017 OpBuilder builder(parser.getContext());
2018 buildIdentityRegion(builder, result.location, *result.addRegion(),
2019 /*inputs=*/result.operands,
2020 /*outputs=*/{});
2021 return success();
2022}
2023
2024void TransposeOp::getAsmResultNames(
2025 function_ref<void(Value, StringRef)> setNameFn) {
2026 if (!getResults().empty())
2027 setNameFn(getResults().front(), "transposed");
2028}
2029
2030void TransposeOp::print(OpAsmPrinter &p) {
2031 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2032 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2033 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2034}
2035
2036LogicalResult TransposeOp::verify() {
2037 ArrayRef<int64_t> permutationRef = getPermutation();
2038
2039 if (!isPermutationVector(permutationRef))
2040 return emitOpError("permutation is not valid");
2041
2042 auto inputType = getInput().getType();
2043 auto initType = getInit().getType();
2044
2045 int64_t rank = inputType.getRank();
2046
2047 if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
2048 "init")))
2049 return failure();
2050
2051 if (rank != static_cast<int64_t>(permutationRef.size()))
2052 return emitOpError() << "size of permutation " << permutationRef.size()
2053 << " does not match the argument rank " << rank;
2054
2055 auto inputDims = inputType.getShape();
2056 auto initDims = initType.getShape();
2057
2058 for (int64_t i = 0; i < rank; ++i) {
2059 int64_t inputDim = inputDims[permutationRef[i]];
2060 int64_t initDim = initDims[i];
2061
2062 if (inputDim != initDim) {
2063 return emitOpError() << "dim(result, " << i << ") = " << initDim
2064 << " doesn't match dim(input, permutation[" << i
2065 << "]) = " << inputDim;
2066 }
2067 }
2068
2069 return success();
2070}
2071
2072SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2073 int64_t rank = getInit().getType().getRank();
2074 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2075}
2076
2077ArrayAttr TransposeOp::getIndexingMaps() {
2078 Builder builder(getContext());
2079 int64_t rank = getInit().getType().getRank();
2080 return builder.getAffineMapArrayAttr(
2081 {inversePermutation(AffineMap::getPermutationMap(
2082 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2083 builder.getMultiDimIdentityMap(rank)});
2084}
2085
2086void TransposeOp::getEffects(
2087 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2088 &effects) {
2089 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2090}
2091
2092Speculation::Speculatability TransposeOp::getSpeculatability() {
2093 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2094}
2095
2096LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2097 SmallVectorImpl<OpFoldResult> &result) {
2098 // Only the tensor type is supported.
2099 if (!isa<TensorType>(getInput().getType()))
2100 return failure();
2101
2102 // Empty permutation (rank-0 input): the transpose is a no-op.
2103 if (getPermutation().empty()) {
2104 result.push_back(getInput());
2105 return success();
2106 }
2107 // Identity permutation.
2108 if (isIdentityPermutation(getPermutation())) {
2109 result.push_back(getInput());
2110 return success();
2111 }
2112
2113 return failure();
2114}
2115
2116/// Fold transpose with transpose.
2117struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2118 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2119
2120 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2121 PatternRewriter &rewriter) const override {
2122 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2123 if (!defTransposeOp)
2124 return failure();
2125 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2126 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2127 SmallVector<int64_t> foldedPerms;
2128 foldedPerms.reserve(perms.size());
2129 for (int64_t perm : perms)
2130 foldedPerms.push_back(defPerms[perm]);
2131
2132 rewriter.replaceOpWithNewOp<TransposeOp>(
2133 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2134 foldedPerms);
2135 return success();
2136 }
2137};
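// For example, with defPerms = [1, 2, 0] and perms = [2, 0, 1], the folded
// permutation is [defPerms[2], defPerms[0], defPerms[1]] = [0, 1, 2], so the
// two transposes compose into a single identity transpose.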
2138
2139/// This pattern canonicalizes transpose by swapping the order of
2140/// broadcast and transpose:
2141/// transpose(broadcast(input)) -> broadcast(transpose(input))
2142struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2143 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2144
2145 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2146 PatternRewriter &rewriter) const override {
2147 Value input = transposeOp.getInput();
2148 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2149 if (!input.hasOneUse() || !broadcastOp)
2150 return failure();
2151
2152 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2153 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2154
2155 // Get new perms and new dimensions.
2156 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2157 SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2158 SmallVector<int64_t> resultDimensions;
2159 unsigned dimensionSize = dimensions.size();
2160 for (unsigned i = 0; i < dimensionSize; ++i)
2161 resultDimensions.push_back(invertPerm[dimensions[i]]);
2162
2163 // Create transpose result.
2164 Value broadcastInput = broadcastOp.getInput();
2165 Location loc = transposeOp.getLoc();
2166 MLIRContext *ctx = transposeOp.getContext();
2167 SmallVector<OpFoldResult> dims;
2168 auto broadcastInputTy =
2169 mlir::cast<RankedTensorType>(broadcastInput.getType());
2170 unsigned inputRank = broadcastInputTy.getRank();
2171 for (unsigned i = 0; i < inputRank; ++i) {
2172 if (broadcastInputTy.isDynamicDim(i)) {
2173 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2174 ->getResult(0));
2175 } else {
2176 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2177 broadcastInputTy.getDimSize(i)));
2178 }
2179 }
2180 SmallVector<OpFoldResult> transposeResultShapes =
2181 applyPermutation(dims, resultPerms);
2182 Value transposeInit = tensor::EmptyOp::create(
2183 rewriter, transposeOp.getLoc(), transposeResultShapes,
2184 broadcastInputTy.getElementType());
2185
2186 // Create broadcast(transpose(input)).
2187 Value transposeResult =
2188 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2189 transposeInit, resultPerms)
2190 ->getResult(0);
2191 rewriter.replaceOpWithNewOp<BroadcastOp>(
2192 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2193 return success();
2194 }
2195};
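// As a sketch of the rewrite (shapes hypothetical):
//
//   %b = linalg.broadcast ins(%in : tensor<2x4xf32>)
//          outs(%e0 : tensor<2x8x4xf32>) dimensions = [1]
//   %t = linalg.transpose ins(%b : tensor<2x8x4xf32>)
//          outs(%e1 : tensor<4x8x2xf32>) permutation = [2, 1, 0]
//
// becomes a transpose of the smaller pre-broadcast value:
//
//   %t2 = linalg.transpose ins(%in : tensor<2x4xf32>)
//           outs(%e2 : tensor<4x2xf32>) permutation = [1, 0]
//   %b2 = linalg.broadcast ins(%t2 : tensor<4x2xf32>)
//           outs(%e1 : tensor<4x8x2xf32>) dimensions = [1]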
2196
2197void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2198 MLIRContext *context) {
2199 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2200}
2201
2202//===----------------------------------------------------------------------===//
2203// BroadcastOp
2204//===----------------------------------------------------------------------===//
2205
2206void BroadcastOp::build(::mlir::OpBuilder &builder,
2207 ::mlir::OperationState &result, Value input, Value init,
2208 DenseI64ArrayAttr dimensions,
2209 ArrayRef<NamedAttribute> attributes) {
2210 result.addOperands(input);
2211 result.addOperands(init);
2212 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2213 result.addAttributes(attributes);
2214
2215 // Add output types for `RankedTensorType` output arguments.
2216 Type initType = init.getType();
2217 if (llvm::isa<RankedTensorType>(initType))
2218 result.addTypes(initType);
2219
2220 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2221 init);
2222}
2223
2224void BroadcastOp::build(::mlir::OpBuilder &builder,
2225 ::mlir::OperationState &result, Value input, Value init,
2226 ArrayRef<int64_t> dimensions,
2227 ArrayRef<NamedAttribute> attributes) {
2228 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2229 attributes);
2230}
2231
2232ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2233 if (failed(parseDstStyleOp(
2234 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2235 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2236 })))
2237 return failure();
2238
2239 OpBuilder builder(parser.getContext());
2240 buildIdentityRegion(builder, result.location, *result.addRegion(),
2241 /*inputs=*/result.operands,
2242 /*outputs=*/{});
2243 return success();
2244}
2245
2246void BroadcastOp::getAsmResultNames(
2247 function_ref<void(Value, StringRef)> setNameFn) {
2248 if (!getResults().empty())
2249 setNameFn(getResults().front(), "broadcasted");
2250}
2251
2252void BroadcastOp::print(OpAsmPrinter &p) {
2253 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2254 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2255 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2256}
2257
2258LogicalResult BroadcastOp::verify() {
2259 ArrayRef<int64_t> dimensionsRef = getDimensions();
2260
2261 auto inputType = getInput().getType();
2262 auto initType = getInit().getType();
2263
2264 int64_t inputRank = inputType.getRank();
2265 int64_t initRank = initType.getRank();
2266
2267 auto inputShape = inputType.getShape();
2268 auto initShape = initType.getShape();
2269
2270 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2271 return emitOpError() << "input rank plus added dimensions does not "
2272 "match init rank. input rank: "
2273 << inputRank
2274 << ", dimensions size: " << dimensionsRef.size()
2275 << ", init rank: " << initRank;
2276
2277 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2278 if (dim < 0 || dim >= initRank)
2279 return emitOpError() << "dimension " << idx
2280 << " is out of range. expected range: [0, "
2281 << initRank - 1 << "], got: " << dim;
2282 }
2283
2284 // Mapping from input dims to init dims.
2285 SmallVector<int64_t> dimMap;
2286 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2287 if (!llvm::is_contained(dimensionsRef, dim))
2288 dimMap.push_back(dim);
2289 }
2290
2291 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2292 // This dimension is mapped from the input. Init and input dims should
2293 // match.
2294 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2295 return emitOpError() << "input dim " << inputDimIdx
2296 << " should match init dim " << initDimIdx
2297 << ". input: " << inputShape[inputDimIdx]
2298 << ", init: " << initShape[initDimIdx];
2299 }
2300
2301 return success();
2302}
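// For example, broadcasting ins(%in : tensor<4x16xf32>) into
// outs(%init : tensor<4x8x16xf32>) requires dimensions = [1]: init dim 1 is
// the added dimension, and input dims 0 and 1 map to init dims 0 and 2.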
2303
2304SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2305 int64_t rank = getInit().getType().getRank();
2306 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2307}
2308
2309ArrayAttr BroadcastOp::getIndexingMaps() {
2310 Builder builder(getContext());
2311 int64_t rank = getInit().getType().getRank();
2312 return builder.getAffineMapArrayAttr(
2313 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2314 builder.getMultiDimIdentityMap(rank)});
2315}
2316
2317void BroadcastOp::getEffects(
2318 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2319 &effects) {
2320 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2321}
2322
2323Speculation::Speculatability BroadcastOp::getSpeculatability() {
2324 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2325}
2326
2327/// Fold back-to-back broadcasts together.
2328struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2329 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2330
2331 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2332 PatternRewriter &rewriter) const override {
2333 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2334 if (!defBroadcastOp)
2335 return failure();
2336 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2337 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2338 SmallVector<int64_t> foldedDims(dimensions);
2339 Value init = broadcastOp.getInit();
2340 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2341 // Mapping from input dims to init dims.
2342 SmallVector<int64_t> dimMap;
2343 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2344 if (!llvm::is_contained(dimensions, dim))
2345 dimMap.push_back(dim);
2346 }
2347 for (auto dim : defDimensions)
2348 foldedDims.push_back(dimMap[dim]);
2349
2350 llvm::sort(foldedDims);
2351 rewriter.replaceOpWithNewOp<BroadcastOp>(
2352 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2353 return success();
2354 }
2355};
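// For example (shapes hypothetical), broadcasting tensor<4xf32> with
// dimensions = [0] to tensor<2x4xf32>, then broadcasting that result with
// dimensions = [2] to tensor<2x4x8xf32>, folds into a single broadcast of the
// original tensor<4xf32> with dimensions = [0, 2].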
2356
2357void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2358 MLIRContext *context) {
2359 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2360}
2361
2362//===----------------------------------------------------------------------===//
2363// YieldOp
2364//===----------------------------------------------------------------------===//
2365
2366void linalg::YieldOp::print(OpAsmPrinter &p) {
2367 if (getNumOperands() > 0)
2368 p << ' ' << getOperands();
2369 p.printOptionalAttrDict((*this)->getAttrs());
2370 if (getNumOperands() > 0)
2371 p << " : " << getOperandTypes();
2372}
2373
2374ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2375 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2376 SmallVector<Type, 2> types;
2377 SMLoc loc = parser.getCurrentLocation();
2378 return failure(parser.parseOperandList(opInfo) ||
2379 parser.parseOptionalAttrDict(result.attributes) ||
2380 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2381 parser.resolveOperands(opInfo, types, loc, result.operands));
2382}
2383
2384// Check that the number and types of the operands match the element types of
2385// the LinalgOp interface's shaped operands.
2386static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2387 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2388 return op.emitOpError("expected number of yield values (")
2389 << op.getNumOperands()
2390 << ") to match the number of inits / outs operands of the enclosing "
2391 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2392
2393 for (OpOperand &opOperand : op->getOpOperands()) {
2394 OpOperand *outputOperand =
2395 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2396 Type elementType = outputOperand->get().getType();
2397 if (isa<MemRefType, RankedTensorType>(elementType))
2398 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2399 if (opOperand.get().getType() != elementType)
2400 return op.emitOpError("type of yield operand ")
2401 << (opOperand.getOperandNumber() + 1) << " ("
2402 << opOperand.get().getType() << ") doesn't match "
2403 << "the element type of the enclosing linalg.generic op ("
2404 << elementType << ")";
2405 }
2406 return success();
2407}
2408
2409LogicalResult linalg::YieldOp::verify() {
2410 auto *parentOp = (*this)->getParentOp();
2411 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2412 return emitOpError("expected single non-empty parent region");
2413
2414 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2415 return verifyYield(*this, linalgOp);
2416
2417 return emitOpError("expected parent op with LinalgOp interface");
2418}
2419
2420//===----------------------------------------------------------------------===//
2421// IndexOp
2422//===----------------------------------------------------------------------===//
2423
2424LogicalResult IndexOp::verify() {
2425 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2426 if (!linalgOp)
2427 return emitOpError("expected parent op with LinalgOp interface");
2428 if (linalgOp.getNumLoops() <= getDim())
2429 return emitOpError("expected dim (")
2430 << getDim() << ") to be lower than the number of loops ("
2431 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2432 return success();
2433}
2434
2435OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2436 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2437 // Bail out if `linalg.index` does not have a proper parent yet at this
2438 // point, e.g., when calling `createOrFold` during IR construction in
2439 // `GenericOp::build`.
2440 if (!linalgOp)
2441 return OpFoldResult{};
2442
2443 // Index of unit dims is always 0.
2444 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2445 uint64_t dim = getDim();
2446 assert(dim < loopBounds.size() && "Dim is out of bounds");
2447 if (loopBounds[dim] == 1)
2448 return IntegerAttr::get(IndexType::get(getContext()), 0);
2449
2450 return OpFoldResult{};
2451}
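// For example, inside a LinalgOp whose static loop ranges are [1, 128],
// `linalg.index 0` folds to the constant index 0, while `linalg.index 1`
// does not fold.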
2452
2453/////// Operations corresponding to library calls defined with Tablegen ////////
2454
2455#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2456
2457#define GET_OP_CLASSES
2458#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2459
2460#define GET_OP_CLASSES
2461#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2462#define GET_OP_CLASSES
2463#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2464
2465AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2466 unsigned rank,
2467 MLIRContext *context) {
2468 if (maybeMap)
2469 return *maybeMap;
2470 if (rank == 0)
2471 return AffineMap::get(context);
2472 return AffineMap::getMultiDimIdentityMap(rank, context);
2473}
2474
2475SmallVector<AffineExpr, 4>
2476mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2477 MLIRContext *context) {
2478 SmallVector<AffineExpr, 4> res;
2479 res.reserve(num);
2480 for (unsigned i = 0; i < num; ++i)
2481 res.push_back(getAffineDimExpr(startIdx++, context));
2482 return res;
2483}
2484
2485SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2486 ArrayRef<AffineExpr> b) {
2487 auto rangeA = llvm::make_range(a.begin(), a.end());
2488 auto rangeB = llvm::make_range(b.begin(), b.end());
2489 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2490 return llvm::to_vector<4>(concatRanges);
2491}
2492
2493static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2494 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2495 ss << "view";
2496 for (auto size : memref.getShape())
2497 if (size < 0)
2498 ss << "sx";
2499 else
2500 ss << size << "x";
2501 if (failed(appendMangledType(ss, memref.getElementType())))
2502 return failure();
2503 if (auto as = memref.getMemorySpace()) {
2504 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2505 ss << "as" << attr.getInt();
2506 else
2507 return failure();
2508 }
2509 return success();
2510 }
2511 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2512 ss << "vector";
2513 llvm::interleave(
2514 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2515 if (failed(appendMangledType(ss, vec.getElementType())))
2516 return failure();
2517 return success();
2518 }
2519 if (t.isSignlessIntOrIndexOrFloat()) {
2520 ss << t;
2521 return success();
2522 }
2523 return failure();
2524}
2525
2526std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2527 assert(isa<LinalgOp>(op));
2528 std::string name(op->getName().getStringRef().str());
2529 std::string fun = "";
2530 for (NamedAttribute kv : op->getAttrs()) {
2531 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2532 fun = stringifyEnum(ufa.getValue()).str() + "_";
2533 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2534 fun = stringifyEnum(bfa.getValue()).str() + "_";
2535 }
2536 }
2537 name.reserve(128);
2538 llvm::replace(name, '.', '_');
2539 llvm::raw_string_ostream ss(name);
2540 ss << "_" << fun;
2541 for (Type t : op->getOperandTypes()) {
2542 if (failed(appendMangledType(ss, t)))
2543 return std::string();
2544 ss << "_";
2545 }
2546 name.pop_back();
2547 return name;
2548}
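// For example (assuming no unary/binary fn attribute is present), a
// linalg.matmul whose three operands all have type memref<?x?xf32> mangles to
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32": the dot becomes an
// underscore and each dynamic memref dimension is mangled as "sx".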
2549
2550//===----------------------------------------------------------------------===//
2551// Canonicalizers and Folders.
2552//===----------------------------------------------------------------------===//
2553
2554namespace {
2555struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2556 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2557
2558 LogicalResult matchAndRewrite(LinalgOp op,
2559 PatternRewriter &rewriter) const override {
2560 for (OpOperand &opOperand : op->getOpOperands()) {
2561 // Linalg "inputs" may be either tensor or memref type.
2562 // tensor<0xelt_type> is a convention that may not always mean
2563 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2564 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2565 if (!mt)
2566 continue;
2567 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2568 rewriter.eraseOp(op);
2569 return success();
2570 }
2571 }
2572 return failure();
2573 }
2574};
2575
2576/// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2577/// result that is more static than the linalg op.
2578struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2579 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2580
2581 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2582 PatternRewriter &rewriter) const override {
2583 if (!tensor::canFoldIntoProducerOp(castOp))
2584 return failure();
2585
2586 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2587 if (!linalgOp)
2588 return failure();
2589
2590 // The cast can be in a conditionally reachable region, in which case
2591 // folding would generate invalid code. Only conservatively fold ops in the
2592 // same block for now.
2593 if (castOp->getBlock() != linalgOp->getBlock())
2594 return failure();
2595
2596 OpBuilder::InsertionGuard guard(rewriter);
2597 rewriter.setInsertionPoint(linalgOp);
2598
2599 Location loc = linalgOp.getLoc();
2600 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2601 unsigned resultNumber = resultValue.getResultNumber();
2602 auto resultType =
2603 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2604 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2605 // going from a more dynamic shape to a less dynamic shape. If the producer
2606 // for this cast, i.e. producer of the out operand, is also an operation
2607 // that folds with tensor.cast consumer (like this pattern), the cast will
2608 // continue to propagate as far up the stack as it can go.
2609 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2610 Value newOperand =
2611 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2612 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2613 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2614 linalgOp.getDpsInits().end());
2615 outputOperands[resultNumber] = newOperand;
2616 newOperands.append(outputOperands.begin(), outputOperands.end());
2617
2618 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2619 linalgOp->result_type_end());
2620 resultTypes[resultNumber] = resultType;
2621 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2622
2623 // Create a tensor.cast operation back to the original type.
2624 Value castBack = tensor::CastOp::create(
2625 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2626
2627 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2628 results[resultNumber] = castBack;
2629 rewriter.replaceOp(linalgOp, results);
2630 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2631 return success();
2632 }
2633};
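// As a sketch of the rewrite (types hypothetical):
//
//   %res = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %cast = tensor.cast %res : tensor<?x?xf32> to tensor<4x8xf32>
//
// becomes
//
//   %new_init = tensor.cast %init : tensor<?x?xf32> to tensor<4x8xf32>
//   %new_res = linalg.generic ... outs(%new_init : tensor<4x8xf32>)
//                -> tensor<4x8xf32>
//   %dyn = tensor.cast %new_res : tensor<4x8xf32> to tensor<?x?xf32>
//
// where uses of %cast are replaced by %new_res and uses of %res by %dyn.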
2634
2635/// For each operand in `operands`, this function maps the static sizes of
2636/// dimensions to their affine dim expressions.
2637static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2638 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2639 for (OpOperand &opOperand : operands) {
2640 if (linalgOp.isScalar(&opOperand))
2641 continue;
2642 Value src = opOperand.get();
2643 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2644 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2645
2646 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2647 // `tensor.cast` operation and source of the cast operation has a static
2648 // shape, then assign it to the `sourceShape`.
2649 auto *parentOp = src.getDefiningOp();
2650 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2651 if (parentOp) {
2652 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2653 Value castSource = castOp.getSource();
2654 auto castSourceType =
2655 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2656 if (castSourceType && castSourceType.hasStaticShape())
2657 sourceShape = castSourceType.getShape();
2658 }
2659 }
2660
2661 // If the source shape's dimension has a static shape, map the affine dim
2662 // expression to the known static size.
2663 for (unsigned i = 0; i < sourceShape.size(); i++) {
2664 if (sourceType.isDynamicDim(i))
2665 continue;
2666 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2667 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2668 }
2669 }
2670}
2671
2672/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2673/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2674/// their result types are stored in `resultTypes`. If `opOperand` requires no
2675/// change then `changeNeeded` is false and the same operand is added to the
2676/// `newOperands` list.
2677static void createNewOperandWithStaticSizes(
2678 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2679 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2680 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2681 bool &changeNeeded) {
2682 Value src = opOperand->get();
2683 newOperands.push_back(src);
2684 if (linalgOp.isScalar(opOperand))
2685 return;
2686 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2687 Type resultType = sourceType;
2688 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2689 resultTypes.push_back(resultType);
2690 return;
2691 }
2692 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2693 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2694 SmallVector<int64_t> newShape;
2695 // If the operand is updated with a new shape, `newOperandNeeded` will be
2696 // true.
2697 bool newOperandNeeded = false;
2698 for (unsigned i = 0; i < sourceShape.size(); i++) {
2699 int64_t dimShape = sourceShape[i];
2700 AffineExpr dimExpr = sourceMap.getResult(i);
2701 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2702 newShape.push_back(dimShape);
2703 continue;
2704 }
2705 // Dimension has a dynamic shape and corresponding affine dim
2706 // expression is present in the map. So assign the size for the
2707 // given affine dim expression to the dimension.
2708 newShape.push_back(affineExprToSize[dimExpr]);
2709 newOperandNeeded = true;
2710 }
2711 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2712 sourceType.getEncoding());
2713 if (newOperandNeeded) {
2714 changeNeeded = true;
2715 // Get the new operand value given its size and element type by
2716 // casting it.
2717 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2718 unsigned index = opOperand->getOperandNumber();
2719 newOperands[index] = newOperand;
2720 }
2721 if (linalgOp.isDpsInit(opOperand))
2722 resultTypes.push_back(resultType);
2723}
2724
2725/// Static shapes for the operands can be inferred if any one of the operands
2726/// has a static shape. This is done by referring to the affine dim
2727/// expressions of the operand.
2728struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2729 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2730
2731 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2732 PatternRewriter &rewriter) const override {
2733 if (!linalgOp.hasPureTensorSemantics())
2734 return failure();
2735
2736 // Maps must be projected permutations.
2737 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2738 return !map.isProjectedPermutation();
2739 }))
2740 return failure();
2741
2742 // Maps affine dim expressions to the static size of that dimension.
2743 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2744 Location loc = linalgOp.getLoc();
2745
2746 // For each affine dim expression, check whether its size is known; if so,
2747 // record it in the map.
2748 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2749
2750 SmallVector<Value> newOperands;
2751 SmallVector<Type> resultTypes;
2752
2753 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2754 // change in their types.
2755 bool changeNeeded = false;
2756 newOperands.reserve(linalgOp->getNumOperands());
2757 resultTypes.reserve(linalgOp.getNumDpsInits());
2758
2759 // Iterate over all the operands and update the static sizes.
2760 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2761 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2762 affineExprToSize, linalgOp, newOperands,
2763 resultTypes, changeNeeded);
2764 }
2765
2766 // If the generic op has all the required static information, no
2767 // canonicalization needed.
2768 if (!changeNeeded)
2769 return failure();
2770
2771 // Clone op.
2772 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2773 SmallVector<Value> replacements;
2774 replacements.reserve(newOp->getNumResults());
2775 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2776 Value newResult = std::get<1>(it);
2777 Value oldResult = std::get<0>(it);
2778 Type newType = newResult.getType();
2779 Type oldType = oldResult.getType();
2780 replacements.push_back(
2781 (newType != oldType)
2782 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2783 : newResult);
2784 }
2785 rewriter.replaceOp(linalgOp, replacements);
2786 return success();
2787 }
2788};
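// As a sketch (op and types hypothetical), for an elementwise op such as
//
//   %0 = linalg.add ins(%a, %b : tensor<4x8xf32>, tensor<?x?xf32>)
//          outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
//
// the static 4x8 shape is inferred from %a: %b and %c are cast to
// tensor<4x8xf32>, the op is cloned with the static types, and the new result
// is cast back to tensor<?x?xf32> for the remaining uses.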
2789
2790} // namespace
2791
2792// All named ops canonicalizers and folders are auto-generated in the
2793// .cpp.inc.
2794
2795//===----------------------------------------------------------------------===//
2796// SoftmaxOp
2797//===----------------------------------------------------------------------===//
2798
2799LogicalResult SoftmaxOp::verify() {
2800 ShapedType inputType = getInputOperandType();
2801 ShapedType outputType = getOutputOperandType();
2802
2803 ArrayRef<int64_t> inputShape = inputType.getShape();
2804 ArrayRef<int64_t> outputShape = outputType.getShape();
2805 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2806 return emitOpError("incompatible output shape");
2807
2808 int64_t inputRank = getInputOperandRank();
2809 int64_t dimension = getDimension();
2810 if ((dimension < 0) || (dimension >= inputRank))
2811 return emitOpError("incorrect dimension specified");
2812
2813 return success();
2814}
2815
2816SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2817 int64_t operandRank = getInputOperandRank();
2818 SmallVector<Range> loopBounds(operandRank);
2819 Location loc = getLoc();
2820 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2821 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2822 Value source = getInput();
2823 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2824 loopBounds[dim].offset = zero;
2825 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2826 loopBounds[dim].stride = one;
2827 }
2828 return loopBounds;
2829}
2830
2831SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2832 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2833 utils::IteratorType::parallel);
2834 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2835 return iteratorTypes;
2836}
2837
2838FailureOr<TilingResult>
2839SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2840 ArrayRef<OpFoldResult> offsets,
2841 ArrayRef<OpFoldResult> sizes) {
2842 int64_t rank = getInputOperandRank();
2843 auto oneAttr = builder.getI64IntegerAttr(1);
2844 SmallVector<OpFoldResult> strides(rank, oneAttr);
2845 SmallVector<Value> tiledOperands;
2846 Operation *inputSlice =
2847 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2848 if (!inputSlice) {
2849 return emitOpError("failed to compute input slice");
2850 }
2851 tiledOperands.emplace_back(inputSlice->getResult(0));
2852 Operation *outputSlice =
2853 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2854 if (!outputSlice) {
2855 return emitOpError("failed to compute output slice");
2856 }
2857 tiledOperands.emplace_back(outputSlice->getResult(0));
2858
2859 SmallVector<Type, 4> resultTypes;
2860 if (hasPureTensorSemantics())
2861 resultTypes.push_back(tiledOperands[1].getType());
2862 Operation *tiledOp =
2863 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2864
2865 return TilingResult{
2866 {tiledOp},
2867 SmallVector<Value>(tiledOp->getResults()),
2868 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2869}
2870
2871LogicalResult SoftmaxOp::getResultTilePosition(
2872 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2873 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2874 SmallVector<OpFoldResult> &resultSizes) {
2875 if (resultNumber == 0) {
2876 resultOffsets.assign(offsets.begin(), offsets.end());
2877 resultSizes.assign(sizes.begin(), sizes.end());
2878 return success();
2879 }
2880 return failure();
2881}
2882
2883// cast(dynamic) -> static.
2884LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2885 return memref::foldMemRefCast(*this);
2886}
2887
2888LogicalResult
2889SoftmaxOp::reifyResultShapes(OpBuilder &b,
2890 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2891 SmallVector<OpFoldResult> shapes;
2892 Location loc = getOperation()->getLoc();
2893 IRRewriter rewriter(b);
2894 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2895 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2896 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2897 if (!outputShapedType.isDynamicDim(dim)) {
2898 // Static dim: Return IntegerAttr.
2899 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2900 } else {
2901 // Dynamic dim: Return Value.
2902 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2903 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2904 }
2905 }
2906 reifiedReturnShapes.emplace_back(std::move(shapes));
2907 return success();
2908}
2909
2910void SoftmaxOp::getEffects(
2911 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2912 &effects) {
2913 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2914 if (!llvm::isa<MemRefType>(operand.getType()))
2915 continue;
2916 effects.emplace_back(MemoryEffects::Read::get(),
2917 &getOperation()->getOpOperand(index), /*stage=*/0,
2918 /*effectOnFullRegion=*/true,
2919 SideEffects::DefaultResource::get());
2920 }
2921
2922 for (OpOperand &operand : getDpsInitsMutable()) {
2923 if (!llvm::isa<MemRefType>(operand.get().getType()))
2924 continue;
2925 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2926 /*effectOnFullRegion=*/true,
2927 SideEffects::DefaultResource::get());
2928 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2929 /*effectOnFullRegion=*/true,
2930 SideEffects::DefaultResource::get());
2931 }
2932}
2933
2934// Helper functions for softmax decomposition.
2935// @{
2936
2937// Helper function to produce the iterator types (reduction or parallel) and
2938// affine maps for the iterators used in the decomposition of softmax.
2939// This method creates:
2940// If allParallel == true:
2941// - iterator type: {parallel, ..., parallel}
2942// - affine maps:
2943// -- identity with inputRank dimensions.
2944// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2945// where N == inputRank - 1.
2946//
2947// If allParallel == false:
2948// - iterator type at dim(i) == parallel for i != \p dim and
2949// dim(dim) == reduction.
2950// - affine map:
2951// -- identity with inputRank dimensions.
2952// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2953// where N == inputRank - 1.
2954static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2955computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2956 int64_t dim, bool allParallel = false) {
2957 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2958 utils::IteratorType::parallel);
2959 if (!allParallel)
2960 iteratorTypes[dim] = utils::IteratorType::reduction;
2961 MLIRContext *ctxt = builder.getContext();
2962 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2963 SmallVector<AffineExpr, 2> affineExprs;
2964 for (int i = 0; i < inputRank; i++) {
2965 if (i != dim)
2966 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2967 }
2968 auto reductionMap =
2969 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2970 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2971 return std::make_tuple(iteratorTypes, indexingMaps);
2972}
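// For example, with inputRank = 3, dim = 1, and allParallel == false, this
// returns iterator types {parallel, reduction, parallel} together with the
// maps (d0, d1, d2) -> (d0, d1, d2) and (d0, d1, d2) -> (d0, d2).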
2973
2974// Helper function to produce a linalg.generic that computes a reduction on
2975// dimension \p dim with the operation type \p T.
2976template <typename T>
2977static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2978 int64_t dim) {
2979 auto inputType = cast<ShapedType>(input.getType());
2980 ArrayRef<int64_t> inputShape = inputType.getShape();
2981 int64_t inputRank = inputShape.size();
2982 auto [iteratorTypes, indexingMaps] =
2983 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2984 assert(indexingMaps.size() == 2 &&
2985 "We should have two maps: 1 for the input, 1 for the output");
2986 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2987
2988 auto genericOp = linalg::GenericOp::create(
2989 builder, loc, output.getType(), input, output, indexingMaps,
2990 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2991 Value result = T::create(b, loc, args[0], args[1]);
2992 linalg::YieldOp::create(b, loc, result);
2993 });
2994 return genericOp.getResult(0);
2995}
2996
2997/// Produce a linalg generic that computes the second step of the softmax
2998/// decomposition: res = exp(input - max), where \p max is the max of \p input
2999/// on dimension \p dim.
3000static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3001 Value max, Value output, int64_t dim) {
3002 auto inputType = cast<ShapedType>(input.getType());
3003 ArrayRef<int64_t> inputShape = inputType.getShape();
3004 int64_t inputRank = inputShape.size();
3005 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3006 builder, inputRank, dim, /*allParallel=*/true);
3007 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3008 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3009 // Add the affine map for the output argument.
3010 indexingMaps.push_back(indexingMaps[0]);
3011 auto genericOp = linalg::GenericOp::create(
3012 builder, loc, input.getType(), ValueRange{input, max}, output,
3013 indexingMaps, iteratorTypes,
3014 [&](OpBuilder &b, Location loc, ValueRange args) {
3015 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3016 Value result = math::ExpOp::create(b, loc, diff);
3017 linalg::YieldOp::create(b, loc, result);
3018 });
3019 return genericOp.getResult(0);
3020}
3021
3022/// Produce a linalg generic that computes the final step of the softmax
3023/// decomposition.
3024/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3025/// yield n / d
3026/// }
3027static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3028 Value denominator, Value output, int64_t dim) {
3029 auto inputType = cast<ShapedType>(numerator.getType());
3030 ArrayRef<int64_t> inputShape = inputType.getShape();
3031 int64_t inputRank = inputShape.size();
3032 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3033 builder, inputRank, dim, /*allParallel=*/true);
3034 assert(indexingMaps.size() == 2 &&
3035 "We should have one map for each input (2)");
3036 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3037 // Add the affine map for the output tensor.
3038 indexingMaps.push_back(indexingMaps[0]);
3039 auto genericOp = linalg::GenericOp::create(
3040 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3041 output, indexingMaps, iteratorTypes,
3042 [&](OpBuilder &b, Location loc, ValueRange args) {
3043 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3044 linalg::YieldOp::create(b, loc, result);
3045 });
3046 return genericOp.getResult(0);
3047}
3048// @} End helper functions for softmax decomposition.
3049
3050/// Given an N-dimensional tensor x, this method converts
3051/// softmax(x) to the following sequence of operations:
3052///
3053/// 1. Compute the max of x along dimension d. This results
3054/// in an N-1 dimensional tensor m.
3055/// m = max(x, dim = d)
3056///
3057/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3058/// an N-dimensional tensor z.
3059/// z = exp(x - m)
3060///
3061/// 3. Compute the sum of z along dimension d. This results in
3062/// an N-1 dimensional tensor l.
3063/// l = sum(z, dim = d)
3064///
3065/// 4. Divide z and l. This gives the N-dimensional softmax.
3066/// softmax = z / l
3067///
3068FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3069 OpBuilder::InsertionGuard guard(b);
3070 b.setInsertionPoint(*this);
3071 Location loc = getLoc();
3072 Value input = getInput();
3073 ShapedType inputType = getInputOperandType();
3074 Type elementType = inputType.getElementType();
3075 int64_t reductionDim = getDimension();
3076 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3077 Value output = getOutput();
3078 dims.erase(dims.begin() + reductionDim);
3079 // Step 1: Compute max along dim.
3080 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3081 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3082 elementType, b, loc,
3083 /*useOnlyFiniteValue=*/true);
3084 Value neutralForMaxFInit =
3085 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3086 .result();
3087 Value max =
3088 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3089
3090 // Step 2: Subtract max from input and exponentiate.
3091 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3092
3093 // Step 3: Compute sum along dim.
3094 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3095 b, loc, /*useOnlyFiniteValue=*/true);
3096 Value zeroInit =
3097 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3098 Value denominator =
3099 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3100
3101 // Step 4: Compute softmax.
3102 Value result =
3103 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3104 return SmallVector<Value>{result};
3105}
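// Worked example (illustrative): decomposing
//   linalg.softmax dimension(1) ins(%x : tensor<2x16xf32>)
// first fills a tensor<2xf32> with the finite maxnumf identity (lowest
// finite f32) and reduces to m = max(x, dim = 1), then computes
// z = exp(x - m) with m broadcast along dim 1, then fills with 0.0 and
// reduces to l = sum(z, dim = 1), and finally returns z / l.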
3106
3107//===----------------------------------------------------------------------===//
3108// WinogradFilterTransformOp
3109//===----------------------------------------------------------------------===//
3110
3111LogicalResult WinogradFilterTransformOp::verify() {
3112 auto filterType = cast<ShapedType>(getFilter().getType());
3113 ArrayRef<int64_t> filterShape = filterType.getShape();
3114 int64_t filterH = filterShape[getFilterHDim()];
3115 int64_t filterW = filterShape[getFilterWDim()];
3116 WinogradConv2DFmr fmr = getFmr();
3117 int64_t m, r;
3118 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3119
3120 if (filterH != r && filterH != 1)
3121 return emitOpError("expected filter height to be r or 1");
3122 if (filterW != r && filterW != 1)
3123 return emitOpError("expected filter width to be r or 1");
3124 if (filterH == 1 && filterW == 1)
3125 return emitOpError("expected either filter height or width to be r");
3126
3127 SmallVector<int64_t> expectedOutputShape;
3128 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3129 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3130 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3131 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3132
3133 auto outputType = cast<ShapedType>(getOutput().getType());
3134 ArrayRef<int64_t> outputShape = outputType.getShape();
3135 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3136 return emitOpError("unexpected output shape");
3137 }
3138 return success();
3139}
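// Example of the verified shapes: with fmr = F_4_3 (m = 4, r = 3), a filter
// with filterH = filterW = 3 must produce an output of shape
// (m + r - 1, m + r - 1, C, F) = (6, 6, C, F); a 3x1 filter (degenerate W
// dimension) must produce (6, 1, C, F).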
3140
3141SmallVector<Range>
3142WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3143 Location loc = getLoc();
3144 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3145 IntegerAttr oneAttr = builder.getIndexAttr(1);
3146 Value filter = getFilter();
3147 int64_t filterRank = getFilterOperandRank();
3148 SmallVector<Range> loopBounds(filterRank);
3149 for (unsigned dim = 0; dim < filterRank; ++dim) {
3150 loopBounds[dim].offset = zeroAttr;
3151 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3152 loopBounds[dim].stride = oneAttr;
3153 }
3154 return loopBounds;
3155}
3156
3157SmallVector<utils::IteratorType>
3158WinogradFilterTransformOp::getLoopIteratorTypes() {
3159 int64_t filterRank = getFilterOperandRank();
3160 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3161 utils::IteratorType::parallel);
3162 return iteratorTypes;
3163}
3164
3165LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3166 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3167 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3168 SmallVector<OpFoldResult> &resultSizes) {
3169 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3170 ShapedType filterType = getFilterOperandType();
3171 ArrayRef<int64_t> filterShape = filterType.getShape();
3172 int64_t filterH = filterShape[getFilterHDim()];
3173 int64_t filterW = filterShape[getFilterWDim()];
3174 WinogradConv2DFmr fmr = getFmr();
3175 int64_t m, r;
3176 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3177 int64_t alpha = m + r - 1;
3178 int64_t alphaH = filterH != 1 ? alpha : 1;
3179 int64_t alphaW = filterW != 1 ? alpha : 1;
3180 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3181 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3182
3183 resultOffsets.append(
3184 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3185 resultSizes.append(
3186 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3187
3188 return success();
3189}
3190
3191/// Implement tiling for winograd_filter_transform.
3192/// The input of winograd_filter_transform is (F, KH, KW, C).
3193/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3194/// Users can specify the tile sizes of F and C.
3195/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3196/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3197FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3198 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3199 ArrayRef<OpFoldResult> sizes) {
3200 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3201 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3202 ShapedType filterType = getFilterOperandType();
3203 ArrayRef<int64_t> filterShape = filterType.getShape();
3204 int64_t filterH = filterShape[getFilterHDim()];
3205 int64_t filterW = filterShape[getFilterWDim()];
3206 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3207 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3208 SmallVector<Value> tiledOperands;
3209 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3210
3211 sliceOffsets.append(
3212 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3213 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3214 sizes[getFilterCDim()]});
3215 int64_t filterRank = getFilterOperandRank();
3216 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3217 Location loc = getLoc();
3218 auto filterSlice = tensor::ExtractSliceOp::create(
3219 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3220 tiledOperands.emplace_back(filterSlice);
3221
3222 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3223 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3224 resultSizes)))
3225 return failure();
3226
3227 int64_t outputRank = getOutputOperandRank();
3228 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3229 auto outputSlice = tensor::ExtractSliceOp::create(
3230 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3231 tiledOperands.emplace_back(outputSlice);
3232
3233 SmallVector<Type> resultTypes;
3234 resultTypes.push_back(tiledOperands[1].getType());
3235 Operation *tiledOp =
3236 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3237
3238 return TilingResult{
3239 {tiledOp},
3240 SmallVector<Value>(tiledOp->getResults()),
3241 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3242}
3243
3244//===----------------------------------------------------------------------===//
3245// WinogradInputTransformOp
3246//===----------------------------------------------------------------------===//
3247
3248LogicalResult WinogradInputTransformOp::verify() {
3249 auto inputType = cast<ShapedType>(getInput().getType());
3250 ArrayRef<int64_t> inputShape = inputType.getShape();
3251 int64_t inputH = inputShape[getInputHDim()];
3252 int64_t inputW = inputShape[getInputWDim()];
3253 WinogradConv2DFmr fmr = getFmr();
3254 int64_t m, r;
3255 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3256 int64_t tileSize = m + r - 1;
3257
3258 auto outputType = cast<ShapedType>(getOutput().getType());
3259 ArrayRef<int64_t> outputShape = outputType.getShape();
3260 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3261 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3262
3263 SmallVector<int64_t> expectedOutputShape(6, inputH);
3264 if (ShapedType::isDynamic(inputH)) {
3265 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3266 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3267 } else {
3268 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3269 expectedOutputShape[getOutputTileHDim()] =
3270 leftTransform ? (inputH - (r - 1)) / m : inputH;
3271 }
3272 if (ShapedType::isDynamic(inputW)) {
3273 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3274 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3275 } else {
3276 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3277 expectedOutputShape[getOutputTileWDim()] =
3278 rightTransform ? (inputW - (r - 1)) / m : inputW;
3279 }
3280 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3281 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3282
3283 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3284 return emitOpError("unexpected output shape");
3285 }
3286 return success();
3287}
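// Example of the verified shapes: for an input of shape (1, 14, 14, 3) with
// F_4_3 (m = 4, r = 3), tileSize = 6 and each spatial dim yields
// (14 - (r - 1)) / m = 12 / 4 = 3 tiles, so the expected output shape is
// (6, 6, 3, 3, 1, 3), i.e. (alphaH, alphaW, tileH, tileW, N, C).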
3288
3289SmallVector<Range>
3290WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3291 Location loc = getLoc();
3292 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3293 IntegerAttr oneAttr = builder.getIndexAttr(1);
3294 Value output = getOutput();
3295 int64_t outputRank = getOutputOperandRank();
3296 SmallVector<Range> loopBounds(outputRank);
3297 for (unsigned dim = 0; dim < outputRank; ++dim) {
3298 loopBounds[dim].offset = zeroAttr;
3299 // alphaH, alphaW, tileH, tileW, N, C
3300 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3301 loopBounds[dim].stride = oneAttr;
3302 }
3303 return loopBounds;
3304}
3305
3306SmallVector<utils::IteratorType>
3307WinogradInputTransformOp::getLoopIteratorTypes() {
3308 int64_t outputRank = getOutputOperandRank();
3309 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3310 utils::IteratorType::parallel);
3311 return iteratorTypes;
3312}
3313
3314LogicalResult WinogradInputTransformOp::getResultTilePosition(
3315 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3316 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3317 SmallVector<OpFoldResult> &resultSizes) {
3318 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3319 ShapedType outputType = getOutputOperandType();
3320 ArrayRef<int64_t> outputShape = outputType.getShape();
3321 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3322 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3323
3324 WinogradConv2DFmr fmr = getFmr();
3325 int64_t m, r;
3326 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3327 int64_t alpha = m + r - 1;
3328 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3329 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3330
3331 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3332 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3333
3334 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3335 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3336 offsets[getOutputCDim()]});
3337 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3338 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3339 sizes[getOutputCDim()]});
3340
3341 return success();
3342}
3343
3344/// Implement tiling for winograd_input_transform.
3345/// The input of winograd_input_transform is (N, H, W, C).
3346/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3347/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3348/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3349/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
3350FailureOr<TilingResult>
3351WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3352 ArrayRef<OpFoldResult> offsets,
3353 ArrayRef<OpFoldResult> sizes) {
3354 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3355 WinogradConv2DFmr fmr = getFmr();
3356 int64_t m, r;
3357 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3358
3359 ShapedType outputType = getOutputOperandType();
3360 ArrayRef<int64_t> outputShape = outputType.getShape();
3361 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3362 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3363
3364 Location loc = getLoc();
3365 MLIRContext *context = builder.getContext();
3366 auto identityAffineMap =
3367 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3368 auto offsetAffineMap =
3369 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3370 Value mappedOffsetH = affine::makeComposedAffineApply(
3371 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3372 offsets[getOutputTileHDim()]);
3373 Value mappedOffsetW = affine::makeComposedAffineApply(
3374 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3375 offsets[getOutputTileWDim()]);
3376 auto sizeAffineMap = AffineMap::get(
3377 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3378 Value mappedSizeH = affine::makeComposedAffineApply(
3379 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3380 Value mappedSizeW = affine::makeComposedAffineApply(
3381 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3382
3383 SmallVector<Value> tiledOperands;
3384 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3385
3386 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3387 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3388 sliceOffsets.append(
3389 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3390 OpFoldResult sizeH =
3391 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3392 OpFoldResult sizeW =
3393 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3394 sliceSizes.append(
3395 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3396 int64_t inputRank = getInputOperandRank();
3397 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3398 auto inputSlice = tensor::ExtractSliceOp::create(
3399 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3400 tiledOperands.emplace_back(inputSlice);
3401
3402 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3403 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3404 resultSizes)))
3405 return failure();
3406
3407 int64_t outputRank = getOutputOperandRank();
3408 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3409 auto outputSlice = tensor::ExtractSliceOp::create(
3410 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3411 tiledOperands.emplace_back(outputSlice);
3412
3413 SmallVector<Type> resultTypes;
3414 resultTypes.push_back(tiledOperands[1].getType());
3415 Operation *tiledOp =
3416 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3417
3418 return TilingResult{
3419 {tiledOp},
3420 SmallVector<Value>(tiledOp->getResults()),
3421 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3422}
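// Example of the offset/size maps above (illustrative, F_2_3 with m = 2,
// r = 3): a tile starting at output-tile index 3 with size 2 reads the input
// at spatial offset 3 * m = 6 with extent 2 * m + (r - 1) = 6, i.e. two
// alpha-wide input windows that overlap by r - 1 and step by m.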
3423
3424//===----------------------------------------------------------------------===//
3425// WinogradOutputTransformOp
3426//===----------------------------------------------------------------------===//
3427
3428LogicalResult WinogradOutputTransformOp::verify() {
3429 auto valueType = cast<ShapedType>(getValue().getType());
3430 ArrayRef<int64_t> valueShape = valueType.getShape();
3431 int64_t valueH = valueShape[getValueAlphaHDim()];
3432 int64_t valueW = valueShape[getValueAlphaWDim()];
3433 int64_t valueTileH = valueShape[getValueTileHDim()];
3434 int64_t valueTileW = valueShape[getValueTileWDim()];
3435 WinogradConv2DFmr fmr = getFmr();
3436 int64_t m, r;
3437 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3438 bool leftTransform = valueH != 1;
3439 bool rightTransform = valueW != 1;
3440
3441 int64_t outputRank = getOutputOperandRank();
3442 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3443 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3444 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3445 } else {
3446 if (valueH != (leftTransform ? m + r - 1 : 1))
3447 return emitOpError("expected input height to equal the input tile size");
3448 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3449 }
3450 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3451 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3452 } else {
3453 if (valueW != (rightTransform ? m + r - 1 : 1))
3454 return emitOpError("expected input width to equal the input tile size");
3455 expectedOutputShape[getOutputWDim()] =
3456 (rightTransform ? m : 1) * valueTileW;
3457 }
3458 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3459 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3460
3461 auto outputType = cast<ShapedType>(getOutput().getType());
3462 ArrayRef<int64_t> outputShape = outputType.getShape();
3463 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3464 return emitOpError("unexpected output shape");
3465 }
3466 return success();
3467}
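// Example of the verified shapes: with F_2_3 (m = 2, r = 3), a value operand
// of shape (4, 4, 7, 7, N, F) has valueH = valueW = m + r - 1 = 4 and must
// produce an output of shape (N, m * 7, m * 7, F) = (N, 14, 14, F).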
3468
3469SmallVector<Range>
3470WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3471 Location loc = getLoc();
3472 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3473 IntegerAttr oneAttr = builder.getIndexAttr(1);
3474 Value value = getValue();
3475 int64_t valueRank = getValueOperandRank();
3476 SmallVector<Range> loopBounds(valueRank);
3477 for (unsigned dim = 0; dim < valueRank; ++dim) {
3478 loopBounds[dim].offset = zeroAttr;
3479 // alphaH, alphaW, tileH, tileW, N, F
3480 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3481 loopBounds[dim].stride = oneAttr;
3482 }
3483 return loopBounds;
3484}
3485
3486SmallVector<utils::IteratorType>
3487WinogradOutputTransformOp::getLoopIteratorTypes() {
3488 int64_t valueRank = getValueOperandRank();
3489 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3490 utils::IteratorType::parallel);
3491 return iteratorTypes;
3492}
3493
3494LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3495 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3496 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3497 SmallVector<OpFoldResult> &resultSizes) {
3498 WinogradConv2DFmr fmr = getFmr();
3499 int64_t m, r;
3500 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3501
3502 Location loc = getLoc();
3503 MLIRContext *context = builder.getContext();
3504 auto identityAffineMap =
3505 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3506 auto affineMap =
3507 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3508
3509 ShapedType valueType = getValueOperandType();
3510 ArrayRef<int64_t> valueShape = valueType.getShape();
3511 int64_t valueH = valueShape[0];
3512 int64_t valueW = valueShape[1];
3513 Value mappedOffsetH = affine::makeComposedAffineApply(
3514 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3515 offsets[getValueTileHDim()]);
3516 Value mappedOffsetW = affine::makeComposedAffineApply(
3517 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3518 offsets[getValueTileWDim()]);
3519 Value mappedSizeH = affine::makeComposedAffineApply(
3520 builder, loc, affineMap, sizes[getValueTileHDim()]);
3521 Value mappedSizeW = affine::makeComposedAffineApply(
3522 builder, loc, affineMap, sizes[getValueTileWDim()]);
3523
3524 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3525 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3526 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3527 OpFoldResult sizeH =
3528 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3529 OpFoldResult sizeW =
3530 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3531
3532 resultOffsets.append(
3533 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3534 resultSizes.append(
3535 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3536 return success();
3537}
3538
3539/// Implement tiling for winograd_output_transform.
3540/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3541/// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
3542/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3543/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3544/// for the sizes of tileH, tileW, N, F for one tile.
3545FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3546 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3547 ArrayRef<OpFoldResult> sizes) {
3548 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3549 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3550 Location loc = getLoc();
3551 SmallVector<Value> tiledOperands;
3552 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3553
3554 ShapedType valueType = getValueOperandType();
3555 ArrayRef<int64_t> valueShape = valueType.getShape();
3556 int64_t alphaH = valueShape[getValueAlphaHDim()];
3557 int64_t alphaW = valueShape[getValueAlphaWDim()];
3558 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3559 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3560
3561 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3562 offsets[getValueTileWDim()], offsets[getValueNDim()],
3563 offsets[getValueFDim()]});
3564 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3565 sizes[getValueTileWDim()], sizes[getValueNDim()],
3566 sizes[getValueFDim()]});
3567 int64_t valueRank = getValueOperandRank();
3568 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3569 auto valueSlice = tensor::ExtractSliceOp::create(
3570 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3571 tiledOperands.emplace_back(valueSlice);
3572
3573 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3574 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3575 resultSizes)))
3576 return failure();
3577
3578 int64_t outputRank = getOutputOperandRank();
3579 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3580 auto outputSlice = tensor::ExtractSliceOp::create(
3581 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3582 tiledOperands.emplace_back(outputSlice);
3583
3584 SmallVector<Type> resultTypes;
3585 resultTypes.push_back(tiledOperands[1].getType());
3586 Operation *tiledOp =
3587 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3588
3589 return TilingResult{
3590 {tiledOp},
3591 SmallVector<Value>(tiledOp->getResults()),
3592 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3593}
3594
3595//===----------------------------------------------------------------------===//
3596// LinalgDialect
3597// TODO: Merge with the LinalgDialect block at the bottom
3598//===----------------------------------------------------------------------===//
3599
3600// Returns true if the result expressions of `subMap` are a subset of `fullMap`'s.
3601static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3602 auto explicitRange = subMap.getResults();
3603 auto defaultRange = fullMap.getResults();
3604 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3605 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3606 llvm::set_union(explicitSet, defaultSet);
3607 return explicitSet == defaultSet;
3608}
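// For instance, with subMap = (d0, d1, d2) -> (d2) and
// fullMap = (d0, d1, d2) -> (d0, d2) this returns true, whereas
// subMap = (d0, d1, d2) -> (d1) against the same fullMap returns false.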
3609
3610/// Check if the user-defined map is a valid broadcast map. Here broadcast
3611/// indexing maps are defined in the context of the corresponding default
3612/// indexing maps for the given op. This makes the check simple: just compare
3613/// the number of result dims.
3614/// Returns true if the explicitMap is broadcast with respect to the
3615/// defaultMap.
3616static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3617 return explicitMap.getNumResults() < defaultMap.getNumResults();
3618}
3619
3620/// Verifies the broadcast and transpose semantics specified by the explicit
3621/// indexing map for the MatmulOp \p op for each operand specified by \p
3622/// opIndex.
3623static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3624 unsigned opIndex) {
3625 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3626 SmallVector<AffineMap, 3> defaultIndexingMaps =
3627 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3628
3629 auto opIndexingMap = opIndexingMaps[opIndex];
3630 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3631 // Check general validity of indexing map results.
3632 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3633 return matmulOp->emitOpError()
3634 << "Unexpected dim expression in map result.";
3635
3636 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3637 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3638 return matmulOp->emitOpError()
3639 << "Invalid broadcast requested, should be (d2).";
3640 }
3641 return success();
3642 }
3643 return success();
3644}
3645
3646// Check general validity of input indexing map of
3647// BatchMatmulOp/BatchReduceMatmulOp.
3648template <typename OpTy>
3649static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3650 AffineMap opIndexingMap,
3651 AffineMap defaultIndexingMap, bool isLHS) {
3652 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3653 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3654 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3655 // Check the result dims are valid.
3656 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3657 return batchVariantMatmulOp->emitOpError()
3658 << "Unexpected result dim expression (outside the set of default "
3659 "result dims).";
3660
3661 // Check for valid number of result dims of input maps.
3662 if (opIndexingMap.getNumResults() > 3)
3663 return batchVariantMatmulOp->emitOpError()
3664 << "no. of result dim expressions exceeds 3.";
3665
3666 auto hasValidBatchDim = [](AffineMap map) {
3667 AffineExpr batchDim = map.getResult(0);
3668 return batchDim.isFunctionOfDim(0);
3669 };
3670
3671 // Check if the requested broadcast is valid.
3672 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3673 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3674 return batchVariantMatmulOp->emitOpError()
3675 << "Invalid broadcast requested.";
3676 } else if (!hasValidBatchDim(opIndexingMap)) {
3677 return batchVariantMatmulOp->emitOpError()
3678 << "Invalid batch dimension expression.";
3679 }
3680 return success();
3681}
3682
3683/// This function checks if the given AffineMap for the output of a
3684/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3685/// dimensions and if the output map result dimensions are valid.
3686template <typename OpTy>
3687static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3688 AffineMap opIndexingMap) {
3689 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3690 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3691 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3692 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3693 opIndexingMap.getNumResults() != 3) {
3694
3695 return batchVariantMatmulOp->emitOpError()
3696 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3697 << ").";
3698 }
3699 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3700 opIndexingMap.getNumResults() != 2) {
3701 return batchVariantMatmulOp->emitOpError()
3702 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3703 << ").";
3704 }
3705
3706 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3707 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3708 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3709 outputMap.getResult(1).isFunctionOfDim(1) &&
3710 outputMap.getResult(2).isFunctionOfDim(2)
3711 : outputMap.getResult(0).isFunctionOfDim(1) &&
3712 outputMap.getResult(1).isFunctionOfDim(2);
3713 };
3714
3715 if (!areValidOutputResultDim(opIndexingMap)) {
3716 return batchVariantMatmulOp->emitOpError()
3717 << "Invalid output map result dimension.";
3718 }
3719
3720 return success();
3721}
3722
3723/// Verifies the broadcast and transpose semantic specified by the explicit
3724/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3725/// specified by opIndex.
3726template <typename OpTy>
3727static LogicalResult
3728verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3729 unsigned opIndex) {
3730 SmallVector<AffineMap, 3> opIndexingMaps =
3731 batchVariantMatmulOp.getIndexingMapsArray();
3732 SmallVector<AffineMap, 3> defaultIndexingMaps =
3733 batchVariantMatmulOp.getDefaultIndexingMaps(
3734 batchVariantMatmulOp->getContext());
3735
3736 if (opIndexingMaps.size() != 3)
3737 return batchVariantMatmulOp->emitOpError()
3738 << "Indexing_map attribute must have 3 affine maps.";
3739
3740 auto opIndexingMap = opIndexingMaps[opIndex];
3741 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3742
3743 if (opIndex == 2 &&
3744 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3745 return failure();
3746
3747 if (opIndex != 2 &&
3748 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3749 defaultIndexingMap, opIndex == 0)))
3750 return failure();
3751
3752 return success();
3753}
3754
3755namespace mlir {
3756namespace linalg {
3757
3758std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3759 if (m == 2 && r == 3)
3760 return WinogradConv2DFmr::F_2_3;
3761 if (m == 4 && r == 3)
3762 return WinogradConv2DFmr::F_4_3;
3763 if (m == 2 && r == 5)
3764 return WinogradConv2DFmr::F_2_5;
3765 return std::nullopt;
3766}
3767
3768std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3769 switch (fmr) {
3770 case WinogradConv2DFmr::F_2_3:
3771 return {2, 3};
3772 case WinogradConv2DFmr::F_4_3:
3773 return {4, 3};
3774 case WinogradConv2DFmr::F_2_5:
3775 return {2, 5};
3776 }
3777 llvm_unreachable("Unknown WinogradConv2DFmr");
3778}
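// Minimal usage sketch: the two helpers round-trip, e.g.
//   std::optional<WinogradConv2DFmr> fmr = getWinogradConv2DFmr(4, 3);
//   auto [m, r] = getFmrFromWinogradConv2DFmr(*fmr); // m == 4, r == 3
// Unsupported (m, r) combinations yield std::nullopt.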
3779
3780//===----------------------------------------------------------------------===//
3781// MatmulOp
3782//===----------------------------------------------------------------------===//
3783
3784static FailureOr<SmallVector<SmallVector<int64_t>>>
3785getAffineResultPositions(ArrayAttr maps) {
3786 SmallVector<SmallVector<int64_t>> positions;
3787 for (auto map : maps) {
3788 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3789 if (!attr)
3790 return failure();
3791 SmallVector<int64_t> pos;
3792 for (auto result : attr.getAffineMap().getResults()) {
3793 auto dim = dyn_cast<AffineDimExpr>(result);
3794 if (!dim)
3795 return failure();
3796 pos.push_back(dim.getPosition());
3797 }
3798 positions.push_back(pos);
3799 }
3800 return positions;
3801}
3802
3803/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3804SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3805 AffineExpr d0, d1, d2;
3806 SmallVector<AffineMap> indexingMaps;
3807 bindDims(context, d0, d1, d2);
3808 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3809 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3810 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3811 return indexingMaps;
3812}
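// The returned maps encode the canonical access pattern
// C(d0, d1) += A(d0, d2) * B(d2, d1), i.e. in IR syntax:
//   affine_map<(d0, d1, d2) -> (d0, d2)> // LHS
//   affine_map<(d0, d1, d2) -> (d2, d1)> // RHS
//   affine_map<(d0, d1, d2) -> (d0, d1)> // output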
3813
3814bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3815 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3816 if (!maps)
3817 return false;
3818 if (maps.size() != 3)
3819 return false;
3820 auto positions = getAffineResultPositions(maps);
3821 if (failed(positions))
3822 return false;
3823 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3824 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3825 (*positions)[2] == SmallVector<int64_t>{0, 1};
3826}
3827
3828SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3829 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3830 utils::IteratorType::parallel,
3831 utils::IteratorType::reduction};
3832}
3833
3834unsigned MatmulOp::getNumRegionArgs() { return 3; }
3835
3836std::string MatmulOp::getLibraryCallName() {
3837 return generateLibraryCallName(getOperation());
3838}
3839
3840bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3841
3842/// Check if the op has broadcast and/or transpose semantic. Returns true if
3843/// the user defined indexing maps are not equal to default map.
3844bool MatmulOp::hasUserDefinedMaps() {
3845 SmallVector<AffineMap, 3> defaultMaps =
3846 getDefaultIndexingMaps(this->getContext());
3847 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3848 return defaultMaps != explicitMaps;
3849}
3850
3851/// Implements the block region builder for the MatmulOp. This is called by
3852/// 'fillStructuredOpRegion'.
3853void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3854 ArrayRef<NamedAttribute> attrs,
3855 function_ref<InFlightDiagnostic()> emitError) {
3856 if (emitError && block.getNumArguments() != 3) {
3857 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3858 << block.getNumArguments();
3859 return;
3860 }
3861 assert(block.getNumArguments() == 3 &&
3862 "MatmulOp regionBuilder expects 3 args");
3863 RegionBuilderHelper helper(b, block);
3864 SmallVector<Value> yields;
3865
3866 TypeFn castVal = TypeFn::cast_signed;
3867 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3868 return attr.getName() == "cast";
3869 });
3870 if (castIter != attrs.end()) {
3871 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3872 castVal = attr.getValue();
3873 }
3874
3875 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3876 block.getArgument(0));
3877 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3878 block.getArgument(1));
3879 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3880 if (!value3)
3881 return;
3882 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3883 value3, emitError);
3884 if (!value4)
3885 return;
3886 yields.push_back(value4);
3887 helper.yieldOutputs(yields);
3888}
3889
3890/// Returns true if the given bcastMap is a valid broadcast map. A valid
3891/// broadcast map must include K dimension.
3892/// TODO: Strict inclusion of K dimension in the broadcast map is not
3893/// necessary for both input matrices simultaneously. We can relax this
3894/// condition to have K dimension for one input matrix map and infer the K
3895/// dimension for other input matrix map from the one already having K
3896/// dimension.
3897bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3898 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3899 AffineExpr expr = bcastMap.getResult(0);
3900 // The map is invalid if the common (K) dimension of the matmul is not found.
3901 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3902}
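// E.g., for matmul's three-dimensional iteration space,
// affine_map<(d0, d1, d2) -> (d2)> is a valid broadcast map since it keeps
// the shared K dimension d2, whereas affine_map<(d0, d1, d2) -> (d0)> is not.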
3903
3904static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3905 if (parser.parseOptionalKeyword("indexing_maps"))
3906 return ArrayAttr{
3907 nullptr}; // Success in case indexing_maps was not provided.
3908
3909 ArrayAttr arrayAttr;
3910 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3911 return failure();
3912
3913 if (llvm::any_of(arrayAttr,
3914 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3915 return parser.emitError(parser.getCurrentLocation())
3916 << "element of indexing_maps array is not an affine_map";
3917
3918 return arrayAttr;
3919}
3920
3921ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3922 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3923 if (failed(indexingMapsAttr))
3924 return failure();
3925
3926 if (*indexingMapsAttr == nullptr) {
3927 auto indexingMapAttrs = llvm::map_to_vector(
3928 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3929 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3930 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3931 }
3932
3933 result.addAttribute("indexing_maps", *indexingMapsAttr);
3934 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3935 MatmulOp::getRegionBuilder());
3936}
3937
3938void MatmulOp::print(OpAsmPrinter &p) {
3939 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3940 MatmulOp::getDefaultIndexingMaps(getContext()),
3941 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3942 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3943 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3944
3945 std::array<StringRef, 3> elidedAttrs = {
3946 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3947 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3948 elidedAttrs);
3949}
3950
3951/// Verify the user defined indexing maps.
3952LogicalResult MatmulOp::verify() {
3953 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3954 if (!hasUserDefinedMaps())
3955 return success();
3956
3957 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3958 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3959 return failure();
3960 }
3961 return success();
3962}
3963
3964LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3965 return memref::foldMemRefCast(*this);
3966}
3967
3968void MatmulOp::getEffects(
3969 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3970 &effects) {
3971 if (hasPureTensorSemantics())
3972 return;
3973 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3974}
3975
3976Speculation::Speculatability MatmulOp::getSpeculatability() {
3977 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3978}
3979
3980SmallVector<AffineMap>
3981MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3982 AffineExpr d0, d1, d2;
3983 MLIRContext *context = builder.getContext();
3984 bindDims(context, d0, d1, d2);
3985 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3986 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3987 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3988 return {mapLHS, mapRHS, mapOut};
3989}
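// These maps encode C(d0, d1) += A(d2, d0) * B(d2, d1): the LHS is read in
// transposed (K x M) layout while the RHS and output keep the default layout.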
3990
3991bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
3992 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3993 if (!maps)
3994 return false;
3995 if (maps.size() != 3)
3996 return false;
3997 auto positions = getAffineResultPositions(maps);
3998 if (failed(positions))
3999 return false;
4000 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
4001 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4002 (*positions)[2] == SmallVector<int64_t>{0, 1};
4003}
4004
4005void MatmulTransposeAOp::build(OpBuilder &builder,
4006 OperationState &result,
4007 ValueRange inputs, ValueRange outputs,
4008 ArrayRef<NamedAttribute> attributes) {
4009 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4010 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4011}
4012
4013MatmulTransposeAOp
4014MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4015 ValueRange inputs, ValueRange outputs,
4016 ArrayRef<NamedAttribute> attributes) {
4017 OperationState state(location, getOperationName());
4018 build(builder, state, inputs, outputs, attributes);
4019 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4020 assert(res && "builder didn't return the right type");
4021 return res;
4022}
4023
4024void MatmulTransposeAOp::build(OpBuilder &builder,
4025 OperationState &result,
4026 TypeRange resultTensorTypes,
4027 ValueRange inputs, ValueRange outputs,
4028 ArrayRef<NamedAttribute> attributes) {
4029 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4030 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4031}
4032
4033MatmulTransposeAOp
4034MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4035 TypeRange resultTensorTypes, ValueRange inputs,
4036 ValueRange outputs,
4037 ArrayRef<NamedAttribute> attributes) {
4038 OperationState state(location, getOperationName());
4039 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4040 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4041 assert(res && "builder didn't return the right type");
4042 return res;
4043}
4044
4045void MatmulTransposeAOp::build(OpBuilder &builder,
4046 OperationState &result,
4047 TypeRange resultTensorTypes,
4048 ValueRange inputs, ValueRange outputs,
4049 Attribute cast,
4050 ArrayRef<NamedAttribute> attributes) {
4051 result.addAttribute("cast", cast);
4052 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4053 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4054}
4055
4056MatmulTransposeAOp
4057MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4058 TypeRange resultTensorTypes, ValueRange inputs,
4059 ValueRange outputs, Attribute cast,
4060 ArrayRef<NamedAttribute> attributes) {
4061 OperationState state(location, getOperationName());
4062 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4063 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4064 assert(res && "builder didn't return the right type");
4065 return res;
4066}
4067
4068bool MatmulTransposeAOp::classof(Operation *op) {
4069 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4070 MatmulTransposeAOp::isDefaultIndexingMaps(
4071 op->getAttr("indexing_maps"));
4072}
4073
4074SmallVector<AffineMap>
4075MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4076 AffineExpr d0, d1, d2;
4077 MLIRContext *context = builder.getContext();
4078 bindDims(context, d0, d1, d2);
4079 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4080 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4081 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4082 return {mapLHS, mapRHS, mapOut};
4083}
4084
4085bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4086 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4087 if (!maps)
4088 return false;
4089 if (maps.size() != 3)
4090 return false;
4091 auto positions = getAffineResultPositions(maps);
4092 if (failed(positions))
4093 return false;
4094 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4095 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4096 (*positions)[2] == SmallVector<int64_t>{0, 1};
4097}
4098
4099void MatmulTransposeBOp::build(OpBuilder &builder,
4100 OperationState &result,
4101 ValueRange inputs, ValueRange outputs,
4102 ArrayRef<NamedAttribute> attributes) {
4103 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4104 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4105}
4106
4107MatmulTransposeBOp
4108MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4109 ValueRange inputs, ValueRange outputs,
4110 ArrayRef<NamedAttribute> attributes) {
4111 OperationState state(location, getOperationName());
4112 build(builder, state, inputs, outputs, attributes);
4113 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4114 assert(res && "builder didn't return the right type");
4115 return res;
4116}
4117
4118void MatmulTransposeBOp::build(OpBuilder &builder,
4119 OperationState &result,
4120 TypeRange resultTensorTypes,
4121 ValueRange inputs, ValueRange outputs,
4122 ArrayRef<NamedAttribute> attributes) {
4123 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4124 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4125}
4126
4127MatmulTransposeBOp
4128MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4129 TypeRange resultTensorTypes, ValueRange inputs,
4130 ValueRange outputs,
4131 ArrayRef<NamedAttribute> attributes) {
4132 OperationState state(location, getOperationName());
4133 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4134 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4135 assert(res && "builder didn't return the right type");
4136 return res;
4137}
4138
4139void MatmulTransposeBOp::build(OpBuilder &builder,
4140 OperationState &result,
4141 TypeRange resultTensorTypes,
4142 ValueRange inputs, ValueRange outputs,
4143 Attribute cast,
4144 ArrayRef<NamedAttribute> attributes) {
4145 result.addAttribute("cast", cast);
4146 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4147 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4148}
4149
4150MatmulTransposeBOp
4151MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4152 TypeRange resultTensorTypes, ValueRange inputs,
4153 ValueRange outputs, Attribute cast,
4154 ArrayRef<NamedAttribute> attributes) {
4155 OperationState state(location, getOperationName());
4156 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4157 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4158 assert(res && "builder didn't return the right type");
4159 return res;
4160}
4161
4162bool MatmulTransposeBOp::classof(Operation *op) {
4163 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4164 MatmulTransposeBOp::isDefaultIndexingMaps(
4165 op->getAttr("indexing_maps"));
4166}
4167
4168SmallVector<AffineMap>
4169BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4170 AffineExpr d0, d1, d2, d3;
4171 MLIRContext *context = builder.getContext();
4172 bindDims(context, d0, d1, d2, d3);
4173 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4174 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4175 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4176 return {mapLHS, mapRHS, mapOut};
4177}
4178
4179bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
4180 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4181 if (!maps)
4182 return false;
4183 if (maps.size() != 3)
4184 return false;
4185 auto positions = getAffineResultPositions(maps);
4186 if (failed(positions))
4187 return false;
4188 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4189 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4190 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4191}
4192
4193void BatchMatmulTransposeAOp::build(
4194 OpBuilder &builder, OperationState &result, ValueRange inputs,
4195 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4196 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4197 BatchMatmulOp::getRegionBuilder(),
4198 getDefaultIndexingMaps(builder));
4199}
4200
4201BatchMatmulTransposeAOp
4202BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4203 ValueRange inputs, ValueRange outputs,
4204 ArrayRef<NamedAttribute> attributes) {
4205 OperationState state(location, getOperationName());
4206 build(builder, state, inputs, outputs, attributes);
4207 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4208 assert(res && "builder didn't return the right type");
4209 return res;
4210}
4211
4212void BatchMatmulTransposeAOp::build(
4213 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4214 ValueRange inputs, ValueRange outputs,
4215 ArrayRef<NamedAttribute> attributes) {
4216 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4217 BatchMatmulOp::getRegionBuilder(),
4218 getDefaultIndexingMaps(builder));
4219}
4220
4221BatchMatmulTransposeAOp
4222BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4223 TypeRange resultTensorTypes, ValueRange inputs,
4224 ValueRange outputs,
4225 ArrayRef<NamedAttribute> attributes) {
4226 OperationState state(location, getOperationName());
4227 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4228 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4229 assert(res && "builder didn't return the right type");
4230 return res;
4231}
4232
4233void BatchMatmulTransposeAOp::build(
4234 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4235 ValueRange inputs, ValueRange outputs, Attribute cast,
4236 ArrayRef<NamedAttribute> attributes) {
4237 result.addAttribute("cast", cast);
4238 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4239 BatchMatmulOp::getRegionBuilder(),
4240 getDefaultIndexingMaps(builder));
4241}
4242
4243BatchMatmulTransposeAOp
4244BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4245 TypeRange resultTensorTypes, ValueRange inputs,
4246 ValueRange outputs, Attribute cast,
4247 ArrayRef<NamedAttribute> attributes) {
4248 OperationState state(location, getOperationName());
4249 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4250 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4251 assert(res && "builder didn't return the right type");
4252 return res;
4253}
4254
4255bool BatchMatmulTransposeAOp::classof(Operation *op) {
4256 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4257 BatchMatmulTransposeAOp::isDefaultIndexingMaps(
4258 op->getAttr("indexing_maps"));
4259}
4260
4261SmallVector<AffineMap>
4262BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4263 AffineExpr d0, d1, d2, d3;
4264 MLIRContext *context = builder.getContext();
4265 bindDims(context, d0, d1, d2, d3);
4266 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4267 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4268 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4269 return {mapLHS, mapRHS, mapOut};
4270}
4271
4272bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4273 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4274 if (!maps)
4275 return false;
4276 if (maps.size() != 3)
4277 return false;
4278 auto positions = getAffineResultPositions(maps);
4279 if (failed(positions))
4280 return false;
4281 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4282 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4283 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4284}
4285
4286void BatchMatmulTransposeBOp::build(
4287 OpBuilder &builder, OperationState &result, ValueRange inputs,
4288 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4289 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4290 BatchMatmulOp::getRegionBuilder(),
4291 getDefaultIndexingMaps(builder));
4292}
4293
4294BatchMatmulTransposeBOp
4295BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4296 ValueRange inputs, ValueRange outputs,
4297 ArrayRef<NamedAttribute> attributes) {
4298 OperationState state(location, getOperationName());
4299 build(builder, state, inputs, outputs, attributes);
4300 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4301 assert(res && "builder didn't return the right type");
4302 return res;
4303}
4304
4305void BatchMatmulTransposeBOp::build(
4306 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4307 ValueRange inputs, ValueRange outputs,
4308 ArrayRef<NamedAttribute> attributes) {
4309 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4310 BatchMatmulOp::getRegionBuilder(),
4311 getDefaultIndexingMaps(builder));
4312}
4313
4314BatchMatmulTransposeBOp
4315BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4316 TypeRange resultTensorTypes, ValueRange inputs,
4317 ValueRange outputs,
4318 ArrayRef<NamedAttribute> attributes) {
4319 OperationState state(location, getOperationName());
4320 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4321 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4322 assert(res && "builder didn't return the right type");
4323 return res;
4324}
4325
4326void BatchMatmulTransposeBOp::build(
4327 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4328 ValueRange inputs, ValueRange outputs, Attribute cast,
4329 ArrayRef<NamedAttribute> attributes) {
4330 result.addAttribute("cast", cast);
4331 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4332 BatchMatmulOp::getRegionBuilder(),
4333 getDefaultIndexingMaps(builder));
4334}
4335
4336BatchMatmulTransposeBOp
4337BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4338 TypeRange resultTensorTypes, ValueRange inputs,
4339 ValueRange outputs, Attribute cast,
4340 ArrayRef<NamedAttribute> attributes) {
4341 OperationState state(location, getOperationName());
4342 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4343 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4344 assert(res && "builder didn't return the right type");
4345 return res;
4346}
4347
4348bool BatchMatmulTransposeBOp::classof(Operation *op) {
4349 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4350 BatchMatmulTransposeBOp::isDefaultIndexingMaps(
4351 op->getAttr("indexing_maps"));
4352}
4353
4354//===----------------------------------------------------------------------===//
4355// ContractOp
4356//===----------------------------------------------------------------------===//
4357
4358SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4359 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4360 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4361 // domains are all the same, and each implements a projected permutation.
4362 // Each iteration space dim must occur for at least one operand and either
4363 // takes part in a contraction/reduction or else has parallel iteration type.
4364 // We have that a dim is a contraction/reduction dim if and only if the dim
4365 // occurs for the output operand. We use this fact for fast inference:
4366 // NB: In case we allow dims to occur solely for one input, the above still
4367 // holds: per the einsum semantics, these are reduction dims as well.
4368 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4369 for (auto result : outAffineMap.getResults()) {
4370 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4371 assert(dimExpr && "affine_map is a projected permutation");
4372 dimsInOutput[dimExpr.getPosition()] = true;
4373 }
4374
4375 SmallVector<utils::IteratorType> iteratorTypes;
4376 for (auto dimOccursInOutput : dimsInOutput)
4377 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4378 : utils::IteratorType::reduction);
4379
4380 return iteratorTypes;
4381}
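// Example: for indexing_maps [(d0, d2), (d2, d1), (d0, d1)], only d2 is
// absent from the output map, so the inferred iterator types are
// ["parallel", "parallel", "reduction"].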
4382
4383unsigned ContractOp::getNumRegionArgs() { return 3; }
4384
4385/// Implements the block region builder, which is called by 'fillStructuredOpRegion'.
4386void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4387 ArrayRef<NamedAttribute> attrs,
4388 function_ref<InFlightDiagnostic()> emitError) {
4389 if (emitError && block.getNumArguments() != 3) {
4390 emitError() << "ContractOp regionBuilder expects 3 args, got "
4391 << block.getNumArguments();
4392 return;
4393 }
4394 assert(block.getNumArguments() == 3 &&
4395 "ContractOp regionBuilder expects 3 args");
4396 RegionBuilderHelper helper(b, block);
4397
4398 TypeFn castSignedness = TypeFn::cast_signed;
4399 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4400 return attr.getName() == "cast";
4401 });
4402 if (castIter != attrs.end()) {
4403 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4404 castSignedness = attr.getValue();
4405 }
4406
4407 // TODO: Support fields with operators besides mult & add.
4408 Type outType = block.getArgument(2).getType();
4409 Value lhsAtOutType =
4410 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4411 Value rhsAtOutType =
4412 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4413 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4414 rhsAtOutType, emitError);
4415 if (!productAtOutType)
4416 return;
4417 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4418 productAtOutType, emitError);
4419 if (!result)
4420 return;
4421 helper.yieldOutputs({result});
4422}
4423
4424ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4425 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4426 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4427 return parser.emitError(parser.getCurrentLocation(),
4428 "expected 'indexing_maps' attribute");
4429 result.addAttribute("indexing_maps", *indexingMapsAttr);
4430
4431 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4432 regionBuilder);
4433}
4434
4435void ContractOp::print(OpAsmPrinter &p) {
4436 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4437 printNamedStructuredOp(
4438 p, getOperation(), getInputs(), getOutputs(),
4439 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4440}
4441
4442LogicalResult ContractOp::verify() {
4443 int iterationSpaceDims = -1;
4444 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4445 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4446 // access an input operand (so occurrence count can be at most 2) and
4447 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4448 SmallVector<size_t> inOccurrences;
4449 SmallVector<size_t> outOccurrences;
4450
4451 // A helper so that for each operand's affine_map and type we check that ...
4452 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4453 bool isInput) -> LogicalResult {
4454 // ... the affine_map is a projected permutation;
4455 if (!affineMap.isProjectedPermutation())
4456 return emitError("provided affine_map is not a projected permutation");
4457
4458 // ... the rank of the affine_map's results and corresponding type match;
4459 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4460 if (affineMap.getNumResults() != shapedType.getRank())
4461 return emitError("ranks of shaped operand and results of corresponding "
4462 "affine_map differ");
4463 } else if (affineMap.getNumResults() != 0) {
4464 return emitError("affine_map specifies shaped access while operand has "
4465 "non-shaped type");
4466 }
4467
4468 // ... the rank of the affine_map's domain is the same as those seen prior;
4469 if (iterationSpaceDims == -1) {
4470 iterationSpaceDims = affineMap.getNumDims();
4471 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4472 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4473 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4474 return emitError("iteration spaces of provided affine_maps differ");
4475 }
4476
4477 // ... update counts of dims used to access either an input or the output.
4478 for (AffineExpr affineExpr : affineMap.getResults()) {
4479 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4480 if (!affineDimExpr)
4481 llvm_unreachable("affine_map is a projected permutation");
4482
4483 if (isInput)
4484 inOccurrences[affineDimExpr.getPosition()] += 1;
4485 else
4486 outOccurrences[affineDimExpr.getPosition()] += 1;
4487 }
4488
4489 return success();
4490 };
4491
4492 for (auto &&[affineMap, operandType, isInput] :
4493 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4494 SmallVector<bool>{true, true, false})) {
4495 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4496 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4497 }
4498
4499 bool hasContractingDim = false;
4500 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4501 size_t inOccCount = inOccurrences[dimIndex];
4502 size_t outOccCount = outOccurrences[dimIndex];
4503
4504 // We have a contracting dim if and only if ...
4505 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4506
4507 if (inOccCount == 0 && outOccCount == 0)
4508 return emitError() << "iteration space dim at index " << dimIndex
4509 << " not used to access any operand";
4510
4511 // NB: We disallow a dim which occurs for only one input operand and not
4512 // for the output. In terms of einsum semantics such dims have a
4513 // sensible meaning - namely an additional reduction per each such dim.
4514 // By contrast, the ContractionOpInterface does not know about this
4515 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4516 // while vector.contract's verifier accepts dims of this kind, many of
4517 // its lowerings give up on encountering these dims.
4518 // TODO: Remove following once we have comprehensive support for input-only
4519 // reduction dims, at both the linalg- and vector-dialect levels.
4520 if (inOccCount == 1 && outOccCount != 1)
4521 return emitError()
4522 << "iteration space dim at index " << dimIndex
4523 << " is neither a contracting dim nor of parallel iteration type";
4524 }
4525
4526 if (!hasContractingDim)
4527 return emitError("'indexing_maps' do not specify a contracting dimension");
4528
4529 return success();
4530}
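// Hedged example (shapes assumed for illustration): a matmul-style
// contraction that satisfies this verifier - d0/d1 are parallel (they occur
// in the output), d2 is the contracting dim (both inputs, not the output):
//
// ```mlir
// %0 = linalg.contract
//        indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                         affine_map<(d0, d1, d2) -> (d2, d1)>,
//                         affine_map<(d0, d1, d2) -> (d0, d1)>]
//        ins(%lhs, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
//        outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>
// ```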
4531
4532LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4533 return memref::foldMemRefCast(*this);
4534}
4535
4536void ContractOp::getEffects(
4537 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4538 &effects) {
4539 if (hasPureTensorSemantics())
4540 return;
4541 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4542}
4543
4544Speculation::Speculatability ContractOp::getSpeculatability() {
4545 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4546}
4547
4548//===----------------------------------------------------------------------===//
4549// Implementation of BatchMatmulOp
4550//===----------------------------------------------------------------------===//
4551SmallVector<AffineMap>
4552BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4553 AffineExpr d0, d1, d2, d3;
4554 SmallVector<AffineMap> indexingMaps;
4555 bindDims(context, d0, d1, d2, d3);
4556 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4557 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4558 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4559 return indexingMaps;
4560}
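// In affine-map notation, the three maps built above are:
//   (d0, d1, d2, d3) -> (d0, d1, d3)   LHS access (batch, m, k)
//   (d0, d1, d2, d3) -> (d0, d3, d2)   RHS access (batch, k, n)
//   (d0, d1, d2, d3) -> (d0, d1, d2)   init/result access (batch, m, n)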
4561
4562bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4563 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4564 if (!maps)
4565 return false;
4566 if (maps.size() != 3)
4567 return false;
4568 auto positions = getAffineResultPositions(maps);
4569 if (failed(positions))
4570 return false;
4571 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4572 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4573 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4574}
4575
4576SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4577 return SmallVector<utils::IteratorType>{
4578 utils::IteratorType::parallel, utils::IteratorType::parallel,
4579 utils::IteratorType::parallel, utils::IteratorType::reduction};
4580}
4581
4582unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4583
4584std::string BatchMatmulOp::getLibraryCallName() {
4585 return generateLibraryCallName(getOperation());
4586}
4587
4588/// Check if the op has broadcast and/or transpose semantics. Returns true if
4589/// the user-defined indexing maps are not equal to the default maps.
4590bool BatchMatmulOp::hasUserDefinedMaps() {
4591 SmallVector<AffineMap, 3> defaultMaps =
4592 getDefaultIndexingMaps(this->getContext());
4593 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4594 return defaultMaps != explicitMaps;
4595}
4596
4597/// Returns true if the given broadcast map `bcastMap` is valid. A valid
4598/// broadcast map must include the K dimension.
4599/// TODO: Strict inclusion of the K dimension in the broadcast map is not
4600/// necessary for both input matrices simultaneously. We can relax this
4601/// condition to require the K dimension in the map of one input matrix and
4602/// infer the K dimension for the map of the other input matrix from the one
4603/// that already has it.
4604bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4605 assert(bcastMap.getNumResults() < 3 &&
4606 "Expected less than 3 result dim expr.");
4607 bool isValid = false;
4608 enum Indices { batchPos, mPos, nPos, kPos };
4609 if (bcastMap.getNumResults() == 1) {
4610 AffineExpr expr = bcastMap.getResult(0);
4611 isValid = expr.isFunctionOfDim(kPos);
4612 } else if (bcastMap.getNumResults() == 2) {
4613 AffineExpr expr0 = bcastMap.getResult(0);
4614 AffineExpr expr1 = bcastMap.getResult(1);
4615 isValid =
4616 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4617 expr0.isFunctionOfDim(mPos)) &&
4618 expr1.isFunctionOfDim(kPos))
4619 : ((expr0.isFunctionOfDim(batchPos) &&
4620 expr1.isFunctionOfDim(kPos)) ||
4621 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4622 }
4623 return isValid;
4624}
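// Hedged example (shapes assumed, mirroring upstream broadcast tests): an
// LHS broadcast along batch and M only needs the K dim in its map:
//
// ```mlir
// linalg.batch_matmul
//     indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,
//                      affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
//                      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
//     ins(%lhs, %rhs : memref<5xf32>, memref<2x5x7xf32>)
//     outs(%acc : memref<2x3x7xf32>)
// ```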
4625
4626void BatchMatmulOp::regionBuilder(
4627 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4628 function_ref<InFlightDiagnostic()> emitError) {
4629 if (emitError && block.getNumArguments() != 3) {
4630 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4631 << block.getNumArguments();
4632 return;
4633 }
4634 assert(block.getNumArguments() == 3 &&
4635 "BatchMatmulOp regionBuilder expects 3 args");
4636 RegionBuilderHelper helper(b, block);
4637 SmallVector<Value> yields;
4638
4639 TypeFn castVal = TypeFn::cast_signed;
4640 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4641 return attr.getName() == "cast";
4642 });
4643 if (castIter != attrs.end()) {
4644 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4645 castVal = attr.getValue();
4646 }
4647
4648 auto toType = block.getArgument(2).getType();
4649 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4650 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4651 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4652 Value addVal =
4653 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4654 yields.push_back(addVal);
4655 helper.yieldOutputs(yields);
4656}
4657
4658ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4659 SmallVector<Attribute, 3> indexingMapsAttr;
4660 Attribute mapAttr;
4661 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4662 if (parser.parseEqual())
4663 return failure();
4664
4665 if (parser.parseLSquare())
4666 return failure();
4667
4668 do {
4669 if (parser.parseAttribute(mapAttr))
4670 return failure();
4671 if (!isa<AffineMapAttr>(mapAttr)) {
4672 return parser.emitError(parser.getCurrentLocation(),
4673 "expected affine map attribute");
4674 }
4675 indexingMapsAttr.push_back(mapAttr);
4676
4677 if (parser.parseOptionalComma())
4678 break;
4679 } while (true);
4680
4681 if (parser.parseRSquare())
4682 return failure();
4683 }
4684 // Initialize indexingMaps, if not supplied explicitly.
4685 if (indexingMapsAttr.empty()) {
4686 indexingMapsAttr = llvm::map_to_vector(
4687 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4688 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4689 }
4690 result.addAttribute("indexing_maps",
4691 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4692
4693 return ::parseNamedStructuredOp(parser, result,
4694 BatchMatmulOp::getNumRegionArgs(),
4695 BatchMatmulOp::getRegionBuilder());
4696}
4697
4698void BatchMatmulOp::print(OpAsmPrinter &p) {
4699 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4700 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4701 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4702 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4703 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4704
4705 std::array<StringRef, 3> elidedAttrs = {
4706 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4707 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4708 elidedAttrs);
4709}
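// Hedged printing example: with the default maps the attribute is elided, so
// only a non-default variant (here, an assumed transposed-A form) prints its
// `indexing_maps` explicitly:
//
// ```mlir
// %0 = linalg.batch_matmul
//        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
//                         affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
//                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
//        ins(%A, %B : tensor<2x8x4xf32>, tensor<2x8x16xf32>)
//        outs(%C : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
// ```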
4710
4711/// Verify the user defined indexing maps.
4712LogicalResult BatchMatmulOp::verify() {
4713 // Verification of pure batch_matmul is handled by
4714 // verifyStructuredOpInterface().
4715 if (!hasUserDefinedMaps())
4716 return success();
4717
4718 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4719 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4720 return failure();
4721 }
4722 return success();
4723}
4724
4725LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4726 SmallVectorImpl<OpFoldResult> &) {
4727 return memref::foldMemRefCast(*this);
4728}
4729
4730void BatchMatmulOp::getEffects(
4731 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4732 &effects) {
4733 if (hasPureTensorSemantics())
4734 return;
4735 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4736}
4737
4738Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4739 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4740}
4741
4742//===----------------------------------------------------------------------===//
4743// ElementwiseOp
4744//===----------------------------------------------------------------------===//
4745//
4746namespace {
4747struct ArityGroupAndKind {
4748 // The enum class {Unary, Binary, Ternary, ..}
4749 ElementwiseArityGroup arityGroup;
4750
4751 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4752 union Kind {
4753 UnaryFn unaryFn;
4754 BinaryFn binaryFn;
4755 TernaryFn ternaryFn;
4756 } kind;
4757};
4758
4759unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4760 return static_cast<unsigned>(arityGroup);
4761}
4762} // namespace
4763
4764static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4765 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4766 constexpr int lastBinary =
4767 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4768 constexpr int lastTernary =
4769 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4770
4771 int val = static_cast<int>(kind);
4772 ArityGroupAndKind result;
4773
4774 if (val < lastUnary) {
4775 result.arityGroup = ElementwiseArityGroup::Unary;
4776 result.kind.unaryFn = static_cast<UnaryFn>(val);
4777 return result;
4778 }
4779 if (val < lastBinary) {
4780 result.arityGroup = ElementwiseArityGroup::Binary;
4781 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4782 return result;
4783 }
4784 if (val >= lastTernary) {
4785 llvm_unreachable("unhandled ElementwiseFn");
4786 }
4787 result.arityGroup = ElementwiseArityGroup::Ternary;
4788 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4789 return result;
4790}
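// Worked decoding example (enum layout assumed from the limits above): a
// `kind` below LastUnary decodes directly as a UnaryFn; the first kind past
// LastUnary decodes as the BinaryFn with index `val - lastUnary`, and
// likewise for ternary kinds relative to LastBinary.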
4791
4792SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4793 auto rank = getResultRank();
4794 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4795}
4796
4797SmallVector<AffineMap>
4798ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4799 MLIRContext *context) {
4800 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4801 return SmallVector<AffineMap>(numMaps, map);
4802}
4803
4804ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4805 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4806 Attribute attr;
4807 mlir::linalg::ElementwiseKind elemwiseKindVal;
4808 if (parser.parseKeyword("kind") || parser.parseEqual())
4809 return failure();
4810
4811 if (succeeded(parser.parseAttribute(attr))) {
4812 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4813 if (!elemwiseKindAttr)
4814 return parser.emitError(parser.getCurrentLocation(),
4815 "expected ElementwiseKind attribute");
4816 elemwiseKindVal = elemwiseKindAttr.getValue();
4817 } else {
4818 return parser.emitError(parser.getCurrentLocation(),
4819 "expected operation 'kind' attribute");
4820 }
4821 result.addAttribute(
4822 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4823
4824 // Parse optional `indexing_maps`
4825 SmallVector<Attribute, 3> indexingMapsAttr;
4826 Attribute mapAttr;
4827 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4828 if (parser.parseEqual())
4829 return failure();
4830 if (parser.parseLSquare())
4831 return failure();
4832 do {
4833 if (parser.parseAttribute(mapAttr))
4834 return failure();
4835 if (!isa<AffineMapAttr>(mapAttr))
4836 return parser.emitError(parser.getCurrentLocation(),
4837 "expected affine map attribute");
4838 indexingMapsAttr.push_back(mapAttr);
4839 if (parser.parseOptionalComma())
4840 break;
4841 } while (true);
4842 if (parser.parseRSquare())
4843 return failure();
4844 }
4845 // At this stage of parsing, the only way to infer the number of region
4846 // args is through the op kind, as input/output tensors are not parsed yet.
4847 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4848 int numRegionArgs =
4849 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4850 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4851 ElementwiseOp::getRegionBuilder())) {
4852 return parser.emitError(parser.getCurrentLocation(),
4853 "unable to parse elemwise op");
4854 }
4855
4856 // Initialize indexingMaps, if not supplied explicitly.
4857 if (indexingMapsAttr.empty()) {
4858 // We need to infer the numDims of the indexing maps from the output
4859 // type, which has already been parsed by now.
4860 auto resultType = result.operands[result.operands.size() - 1].getType();
4861 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4862 if (!shapedType)
4863 return parser.emitError(parser.getCurrentLocation(),
4864 "return type needs to be shaped type");
4865 auto numDims = shapedType.getRank();
4866 indexingMapsAttr = llvm::map_to_vector(
4867 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4868 parser.getContext()),
4869 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4870 }
4871
4872 result.addAttribute("indexing_maps",
4873 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4874 return success();
4875}
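// Hedged example of the syntax this parser accepts (identity indexing maps
// are inferred from the rank-2 output and therefore elided):
//
// ```mlir
// %0 = linalg.elementwise kind=#linalg.elementwise_kind<add>
//        ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
//        outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>
// ```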
4876
4877void ElementwiseOp::print(OpAsmPrinter &p) {
4878 p << " kind=";
4879 p.printAttribute(getKindAttr());
4880 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4881 "indexing_maps"};
4882 unsigned arity =
4883 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4884 unsigned numDims = getResultRank();
4885
4886 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4887 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4888 getContext()),
4889 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4890
4891 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4892 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4893
4894 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4895 elidedAttrs);
4896}
4897
4898/// Implements the block region builder for the ElementwiseOp. This is called by
4899/// 'fillStructuredOpRegion'.
4900void ElementwiseOp::regionBuilder(
4901 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4902 function_ref<InFlightDiagnostic()> emitError) {
4903 ElementwiseKind elemwiseKind;
4904 for (auto attr : attrs) {
4905 if (attr.getName() == b.getStringAttr("kind")) {
4906 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4907 assert(kindAttr && "op kind attribute incorrectly set");
4908 elemwiseKind = kindAttr.getValue();
4909 break;
4910 }
4911 }
4912
4913 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4914 auto arityGroup = groupAndKind.arityGroup;
4915 auto kind = groupAndKind.kind;
4916 if (emitError && block.getNumArguments() !=
4917 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4918 emitError() << "Elementwise regionBuilder expects "
4919 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4920 << block.getNumArguments();
4921 return;
4922 }
4923 assert(block.getNumArguments() ==
4924 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4925 && "Elementwise regionBuilder number of block args mismatch");
4926
4927 RegionBuilderHelper helper(b, block);
4928 SmallVector<Value> yields;
4929 Value result;
4930
4931 if (arityGroup == ElementwiseArityGroup::Unary) {
4932 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4933
4934 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4935 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4936 block.getArgument(1));
4937
4938 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4939 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4940 block.getArgument(1), block.getArgument(2));
4941
4942 } else {
4943 assert(false && "found unhandled category in elemwise");
4944 }
4945
4946 yields.push_back(result);
4947 helper.yieldOutputs(yields);
4948}
4949
4950LogicalResult ElementwiseOp::fold(FoldAdaptor,
4951 SmallVectorImpl<OpFoldResult> &) {
4952 return memref::foldMemRefCast(*this);
4953}
4954
4955void ElementwiseOp::getEffects(
4956 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4957 &effects) {
4958 if (hasPureTensorSemantics())
4959 return;
4960 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4961}
4962
4963Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4964 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4965}
4966
4967//===----------------------------------------------------------------------===//
4968// PackOp/UnPackOp Common
4969//===----------------------------------------------------------------------===//
4970
4971template <typename OpTy, typename>
4972SmallVector<int64_t>
4973getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4974 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4975 ? packOrUnPack.getDestType()
4976 : packOrUnPack.getSourceType();
4977 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4978 ? packOrUnPack.getSourceType()
4979 : packOrUnPack.getDestType();
4980 SmallVector<int64_t> result(
4981 packedType.getShape().take_front(unpackedType.getRank()));
4982 if (!packOrUnPack.getOuterDimsPerm().empty()) {
4983 applyPermutationToVector(
4984 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4985 }
4986 return result;
4987}
4988template SmallVector<int64_t>
4989 getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
4990template SmallVector<int64_t>
4991 getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
4992
4993// Given the (potentially) updated packed type, `newPackedTy`, generates an
4994// updated mixed-tile-sizes attribute. A tile size is updated only
4995// when:
4996// * a dim from newPackedTy is static, and
4997// * the corresponding size from mixedTiles is still dynamic.
4998// Otherwise, the original tile size is preserved.
4999// Note - packed-type-dim and mixed-tile-size should always match!
5000static SmallVector<OpFoldResult>
5001getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
5002 SmallVector<OpFoldResult> mixedTiles) {
5003 SmallVector<OpFoldResult> newMixedTileSizes;
5004 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5005 .getShape()
5006 .take_back(mixedTiles.size()),
5007 mixedTiles)) {
5008 int64_t shape = std::get<0>(it);
5009 if (shape == ShapedType::kDynamic) {
5010 newMixedTileSizes.push_back(std::get<1>(it));
5011 continue;
5012 }
5013
5014 // If the current result dim is static, update the dynamic mixed-size
5015 // (provided the original value is dynamic).
5016 OpFoldResult tile = std::get<1>(it);
5017 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
5018 // Already a constant
5019 newMixedTileSizes.push_back(tile);
5020 } else {
5021 assert(getConstantIntValue(tile).value() == shape &&
5022 "tile size and dim size don't match!");
5023 newMixedTileSizes.push_back(
5024 (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
5025 }
5026 }
5027
5028 return newMixedTileSizes;
5029}
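// Worked sketch (values assumed): with newPackedTy = tensor<16x8x8x32xf32>
// and mixedTiles = {%tile, 32}, the trailing packed dims are (8, 32); the
// dynamic %tile is replaced by the constant 8 read off the packed type,
// while the already-constant 32 is preserved as-is.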
5030
5031template <typename OpTy>
5032static LogicalResult
5033reifyResultShapesImpl(OpTy op, OpBuilder &builder,
5034 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5035 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5036 "applies to only pack or unpack operations");
5037 int64_t destRank = op.getDestRank();
5038 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5039 reifiedReturnShapes[0] =
5040 tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
5041 return success();
5042}
5043
5044template <typename OpTy>
5045static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
5046 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5047 "applies to only pack or unpack operations");
5048 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5049 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5050 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5051 assert(tiles.size() == dimsToTile.size() &&
5052 "tiles must match indices of dimension to block");
5053 // bind the dimension `i` with the tile factor.
5054 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5055 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5056 return dimAndTileMapping;
5057}
5058
5059template <typename OpTy>
5060static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
5061 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5062 "applies to only pack or unpack operations");
5063 Builder builder(op);
5064 SmallVector<OpFoldResult> mixedInnerTiles;
5065 unsigned dynamicValIndex = 0;
5066 for (int64_t staticTile : op.getStaticInnerTiles()) {
5067 if (ShapedType::isStatic(staticTile))
5068 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5069 else
5070 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5071 }
5072 return mixedInnerTiles;
5073}
5074
5075template <typename OpTy>
5076static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
5077 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5078 "applies to only pack or unpack operations");
5079 SmallVector<Value> dynamicTiles;
5080 SmallVector<int64_t> staticTiles;
5081 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5082 return staticTiles;
5083}
5084
5085/// Returns true if `dimsPos` is invalid. It is invalid when:
5086/// a) It contains duplicates.
5087/// b) At least one dimension is out of bounds, i.e., not in [0, `rank`).
5088/// c) The number of elements in `dimsPos` is greater than `rank`.
5089static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
5090 size_t rank) {
5091 size_t dimsPosSize = dimsPos.size();
5092 if (dimsPosSize > rank)
5093 return true;
5094 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5095 if (dimsPosSize != uniqued.size())
5096 return true;
5097 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5098 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5099 });
5100}
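// Worked examples: with rank = 3, [0, 2] is valid, while [1, 1] (duplicate),
// [0, 3] (out of bounds), and [0, 1, 2, 3] (more entries than rank) are all
// invalid.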
5101
5102template <typename OpTy>
5103static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5104 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5105 "applies to only pack or unpack operations");
5106 Operation *op = packOrUnPack.getOperation();
5107
5108 // Return true if we have a zero-value tile.
5109 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5110 return llvm::any_of(tiles, isZeroInteger);
5111 };
5112
5113 // Verify that the source and destination are ranked types.
5114 if (!packOrUnPack.getSourceType().hasRank() ||
5115 !packOrUnPack.getDestType().hasRank())
5116 return op->emitError("expected both source and destination to have rank");
5117
5118 // Verify that the Operation does not have mixed tensor/buffer semantics.
5119 if (!packOrUnPack.hasPureBufferSemantics() &&
5120 !packOrUnPack.hasPureTensorSemantics())
5121 return op->emitError("mixing tensor and buffer semantics is not allowed");
5122 const unsigned numResults = packOrUnPack.getNumResults();
5123 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5124 return op->emitError("expected 1 result, got ") << numResults;
5125 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5126 return op->emitError("expected 0 results, got ") << numResults;
5127
5128 // Verify tiles. Do not allow zero tiles.
5129 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5130 if (hasZeros(mixedTiles))
5131 return op->emitError("invalid zero tile factor");
5132
5133 // Verify inner_dims_pos and outer_dims_perm.
5134 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5135 ? packOrUnPack.getSourceType()
5136 : packOrUnPack.getDestType();
5137 size_t unpackedRank = unpackedType.getRank();
5138 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5139 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5140 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5141 return op->emitError("invalid inner_dims_pos vector");
5142 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5143 return op->emitError("invalid outer_dims_perm vector");
5144 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5145 return op->emitError("outer_dims_perm must be a permutation or empty");
5146
5147 // Tiling factors must be less than or equal to the input rank for pack (or
5148 // output rank for unpack), and must match the number of `inner_dims_pos`.
5149 if (mixedTiles.size() > unpackedRank) {
5150 return op->emitError("tiling factors must be less than or equal to the "
5151 "input rank for pack or output rank for unpack");
5152 }
5153 if (mixedTiles.size() != innerDimsPos.size()) {
5154 return op->emitError(
5155 "tiling factors must equal the number of dimensions to tile");
5156 }
5157
5158 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5159 ? packOrUnPack.getDestType()
5160 : packOrUnPack.getSourceType();
5161 size_t packedRank = packedType.getRank();
5162 // Require output rank to match input rank + number of blocking factors.
5163 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5164 if (expectedPackedRank != packedRank) {
5165 return op->emitError(
5166 "packed rank != (unpacked rank + num tiling factors), got ")
5167 << packedRank << " != " << expectedPackedRank;
5168 }
5169
5170 // Verify result shape is greater than the minimum expected
5171 // by the pack operation, and that the output shape
5172 // represents full tiles.
5173 SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
5174 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5175 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5176 if (!llvm::all_of(
5177 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5178 mixedTiles),
5179 [](std::tuple<int64_t, OpFoldResult> it) {
5180 int64_t shape = std::get<0>(it);
5181 if (Attribute attr =
5182 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5183 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5184 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5185 return shape == staticTileSize;
5186 }
5187 return ShapedType::isDynamic(shape);
5188 })) {
5189 return op->emitError("mismatch in inner tile sizes specified and shaped of "
5190 "tiled dimension in the packed type");
5191 }
5192 if (failed(
5193 verifyCompatibleShape(expectedPackedShape, packedType.getShape()))) {
5194 auto elementType = unpackedType.getElementType();
5195 Type expectedType, actualType;
5196 if (packOrUnPack.hasPureTensorSemantics()) {
5197 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5198 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5199 } else {
5200 expectedType = MemRefType::get(expectedPackedShape, elementType);
5201 actualType = MemRefType::get(packedType.getShape(), elementType);
5202 }
5203 return op->emitError("expected ")
5204 << expectedType << " for the packed domain value, got "
5205 << actualType;
5206 }
5207 return success();
5208}
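// Hedged example of an op that passes these checks: packed rank 4 equals
// unpacked rank 2 plus two tiling factors, and each trailing packed dim
// matches its static inner tile:
//
// ```mlir
// %0 = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//        into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// ```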
5209
5210namespace {
5211/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5212/// various permutations to the op.
5213// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5214// these. These may or may not become true foldings / canonicalizations
5215// depending on how aggressive we want to be in automatically folding
5216// transposes.
5217struct PackOrUnPackTransposeResult {
5218 SmallVector<int64_t> innerDimsPos;
5219 SmallVector<OpFoldResult> innerTiles;
5220 SmallVector<int64_t> outerDimsPerm;
5221};
5222} // namespace
5223
5224template <typename OpTy>
5225static PackOrUnPackTransposeResult
5226commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
5227 ArrayRef<int64_t> innerPermutation,
5228 ArrayRef<int64_t> outerPermutation) {
5229 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5230 "applies to only pack or unpack operations");
5231 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5232 "some permutation must be non-empty");
5233 PackOrUnPackTransposeResult metadata;
5234 metadata.innerDimsPos =
5235 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5236 metadata.innerTiles =
5237 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
5238 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5239 ? packOrUnPackOp.getSourceRank()
5240 : packOrUnPackOp.getDestRank();
5241 metadata.outerDimsPerm =
5242 packOrUnPackOp.getOuterDimsPerm().empty()
5243 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5244 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5245 if (!innerPermutation.empty()) {
5246 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5247 isPermutationVector(innerPermutation) &&
5248 "invalid inner permutation");
5249 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5250 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5251 }
5252 if (!outerPermutation.empty()) {
5253 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5254 isPermutationVector(outerPermutation) &&
5255 "invalid outer permutation");
5256 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5257 }
5258 return metadata;
5259}
5260
5261//===----------------------------------------------------------------------===//
5262// PackOp
5263//===----------------------------------------------------------------------===//
5264
5265void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5266 if (!getResults().empty())
5267 setNameFn(getResult(), "pack");
5268}
5269
5270ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
5271 OpAsmParser::UnresolvedOperand source, dest;
5272 SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
5273 SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
5274 SmallVector<Type> paddingValueType;
5275 SmallVector<int64_t> staticTiles;
5276 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
5277 Type sourceType, destType, resultType;
5278
5279 if (parser.parseOperand(source))
5280 return failure();
5281
5282 if (succeeded(parser.parseOptionalKeyword("padding_value"))) {
5283 if (parser.parseLParen() ||
5284 parser.parseOperandList(paddingValue, /*requiredOperandCount=*/1) ||
5285 parser.parseColon() || parser.parseTypeList(paddingValueType) ||
5286 parser.parseRParen())
5287 return failure();
5288 }
5289
5290 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
5291 if (parser.parseEqual())
5292 return failure();
5293
5294 SmallVector<int64_t> outerDimsPermVec;
5295 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
5296 int64_t value;
5297 if (parser.parseInteger(value))
5298 return failure();
5299 outerDimsPermVec.push_back(value);
5300 return success();
5301 }))
5302 return failure();
5303 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
5304 }
5305
5306 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
5307 return failure();
5308
5309 SmallVector<int64_t> innerDimsPosVec;
5310 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
5311 int64_t value;
5312 if (parser.parseInteger(value))
5313 return failure();
5314 innerDimsPosVec.push_back(value);
5315 return success();
5316 }))
5317 return failure();
5318 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
5319
5320 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
5321 return failure();
5322
5323 DenseI64ArrayAttr staticTilesAttr;
5324 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
5325 return failure();
5326 for (auto val : staticTilesAttr.asArrayRef())
5327 staticTiles.push_back(val);
5328
5329 if (parser.parseKeyword("into") || parser.parseOperand(dest))
5330 return failure();
5331
5332 if (parser.parseOptionalAttrDict(result.attributes))
5333 return failure();
5334
5335 if (parser.parseColon() || parser.parseType(sourceType))
5336 return failure();
5337
5338 bool hasArrow = succeeded(parser.parseOptionalArrow());
5339 if (hasArrow) {
5340 if (parser.parseType(destType))
5341 return failure();
5342 }
5343
5344 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5345 if (!hasArrow) {
5346 return parser.emitError(parser.getCurrentLocation(),
5347 "pack/unpack requires '->' and destination type");
5348 }
5349
5350 if (!isMemRef)
5351 resultType = destType;
5352
5353 if (parser.resolveOperand(source, sourceType, result.operands) ||
5354 parser.resolveOperand(dest, destType, result.operands))
5355 return failure();
5356
5357 if (!paddingValue.empty() &&
5358 parser.resolveOperands(paddingValue, paddingValueType[0],
5359 result.operands))
5360 return failure();
5361
5362 if (!dynamicTiles.empty() &&
5363 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
5364 result.operands))
5365 return failure();
5366
5367 result.addAttribute("static_inner_tiles",
5368 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
5369 result.addAttribute("inner_dims_pos", innerDimsPos);
5370 if (outerDimsPerm)
5371 result.addAttribute("outer_dims_perm", outerDimsPerm);
5372
5373 SmallVector<int32_t> segmentSizes = {
5374 1, 1, static_cast<int32_t>(paddingValue.size()),
5375 static_cast<int32_t>(dynamicTiles.size())};
5376 result.addAttribute("operandSegmentSizes",
5377 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
5378
5379 if (!isMemRef)
5380 result.addTypes(resultType);
5381
5382 return success();
5383}
5384
5385void PackOp::print(OpAsmPrinter &p) {
5386 p << " " << getSource();
5387
5388 if (getPaddingValue()) {
5389 p << " padding_value(" << getPaddingValue() << " : "
5390 << getPaddingValue().getType() << ")";
5391 }
5392
5393 if (!getOuterDimsPerm().empty()) {
5394 p << " outer_dims_perm = [";
5395 llvm::interleaveComma(getOuterDimsPerm(), p);
5396 p << "]";
5397 }
5398
5399 p << " inner_dims_pos = [";
5400 llvm::interleaveComma(getInnerDimsPos(), p);
5401 p << "]";
5402
5403 p << " inner_tiles = ";
5404 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
5405
5406 p << " into " << getDest();
5407
5408 p.printOptionalAttrDict((*this)->getAttrs(),
5409 {"static_inner_tiles", "inner_dims_pos",
5410 "outer_dims_perm", "operandSegmentSizes"});
5411
5412 p << " : " << getSource().getType();
5413 p << " -> " << getDest().getType();
5414}
5415
5416void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5417 Value dest, ArrayRef<int64_t> innerDimsPos,
5418 ArrayRef<OpFoldResult> innerTiles,
5419 std::optional<Value> paddingValue,
5420 ArrayRef<int64_t> outerDimsPerm) {
5421 assert(innerDimsPos.size() == innerTiles.size() &&
5422 "number of tile sizes specified must match the specified number of "
5423 "original dimensions to be tiled");
5424 SmallVector<int64_t> staticTileSizes;
5425 SmallVector<Value> dynamicTileSizes;
5426 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5427 build(builder, state, dest.getType(), source, dest,
5428 paddingValue ? *paddingValue : nullptr,
5429 outerDimsPerm.empty() ? nullptr
5430 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5431 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5432 builder.getDenseI64ArrayAttr(staticTileSizes));
5433}
5434
5435LogicalResult
5436PackOp::reifyResultShapes(OpBuilder &builder,
5437 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5438 if (!hasPureTensorSemantics())
5439 return failure();
5440 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5441}
5442
5443DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5444 return getDimAndTileMappingImpl(*this);
5445}
5446
5447SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5448 return getMixedTilesImpl(*this);
5449}
5450
5451SmallVector<int64_t> PackOp::getStaticTiles() {
5452 return getStaticTilesImpl(*this);
5453}
5454
5455ArrayRef<int64_t> PackOp::getAllOuterDims() {
5456 ShapedType inputType = getSourceType();
5457 int64_t inputRank = inputType.getRank();
5458 return getDestType().getShape().take_front(inputRank);
5459}
5460
5461SmallVector<int64_t> PackOp::getTiledOuterDims() {
5462 auto innerDimsPos = getInnerDimsPos();
5463 SmallVector<int64_t> outerDims(getAllOuterDims());
5464 SmallVector<int64_t> res;
5465
5466 // Recover the original order of the outer dims.
5467 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5468 invertPermutationVector(outerDimPermInv);
5469 if (!outerDimPermInv.empty())
5470 applyPermutationToVector(outerDims, outerDimPermInv);
5471
5472 // Collect the outer dims corresponding to the tiled inner dims.
5473 for (auto index : innerDimsPos)
5474 res.push_back(outerDims[index]);
5475
5476 return res;
5477}
5478
5479bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5480 ArrayRef<int64_t> innerDimsPos,
5481 ArrayRef<int64_t> outputShape,
5482 ArrayRef<int64_t> outerDimsPerm,
5483 ArrayRef<OpFoldResult> innerTiles) {
5484 SmallVector<int64_t> outputTileSizes(
5485 outputShape.take_front(inputShape.size()));
5486 if (!outerDimsPerm.empty()) {
5487 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5488 "expected output and outer_dims_perm to have same size");
5489 applyPermutationToVector(outputTileSizes,
5490 invertPermutationVector(outerDimsPerm));
5491 }
5492 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5493 if (ShapedType::isDynamic(inputShape[pos]))
5494 continue;
5495 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5496
5497 if (!constantTile) {
5498 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5499 (inputShape[pos] % outputTileSizes[pos] != 0))
5500 return true;
5501 } else if (inputShape[pos] % (*constantTile) != 0) {
5502 return true;
5503 }
5504 }
5505 return false;
5506}
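// Worked example (shapes assumed): inputShape = [13, 15], innerDimsPos =
// [0, 1], innerTiles = [8, 8]: 13 % 8 != 0, so a padding value is required;
// with inputShape = [16, 24] both tiles divide evenly and none is needed.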
5507
5508bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5509 ArrayRef<int64_t> innerDimsPos,
5510 ArrayRef<int64_t> outputShape,
5511 ArrayRef<int64_t> outerDimsPerm,
5512 ArrayRef<OpFoldResult> innerTiles) {
5513 SmallVector<int64_t> outputTileSizes(
5514 outputShape.take_front(inputShape.size()));
5515 if (!outerDimsPerm.empty()) {
5516 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5517 "expected output and outer_dims_perm to have same size");
5518 applyPermutationToVector(outputTileSizes,
5519 invertPermutationVector(outerDimsPerm));
5520 }
5521 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5522 if (ShapedType::isDynamic(inputShape[pos]) ||
5523 ShapedType::isDynamic(outputTileSizes[pos]))
5524 return true;
5525 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5526 if (!constantTile)
5527 return true;
5528 if (inputShape[pos] % (*constantTile) != 0)
5529 return true;
5530 }
5531 return false;
5532}
5533
5534LogicalResult PackOp::verify() {
5535 if (failed(commonVerifierPackAndUnPackOp(*this)))
5536 return failure();
5537
5538 // Verify padding value, and bail out if the tile does not divide the
5539 // dimension fully. In the case of dynamic tile factors or dimensions, having
5540 // a partial tile is undefined behavior.
5541 auto paddingValue = getPaddingValue();
5542 if (paddingValue &&
5543 paddingValue.getType() != getSourceType().getElementType()) {
5544 return emitOpError("expected padding_value has ")
5545 << getSourceType().getElementType()
5546 << " but got: " << paddingValue.getType();
5547 }
5548
5549 if (!paddingValue &&
5550 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5551 getDestType().getShape(), getOuterDimsPerm(),
5552 getMixedTiles())) {
5553 return emitOpError(
5554 "invalid tile factor or output size provided. Only full tiles are "
5555 "supported when padding_value is not set");
5556 }
5557 return success();
5558}
5559
5560/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5561/// Values to kDynamic, even if they are arith.constant values.
5562static SmallVector<int64_t>
5563asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
5564 SmallVector<int64_t> result;
5565 for (auto o : ofrs) {
5566 // Have to do this first, as getConstantIntValue special-cases constants.
5567 if (llvm::dyn_cast_if_present<Value>(o))
5568 result.push_back(ShapedType::kDynamic);
5569 else
5570 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5571 }
5572 return result;
5573}
5574
5575SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
5576 ArrayRef<int64_t> innerTileSizes,
5577 ArrayRef<int64_t> innerDimsPos,
5578 ArrayRef<int64_t> outerDimsPerm) {
5579 SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
5580 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5581 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5582 continue;
5583 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5584 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5585 continue;
5586 }
5587 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5588 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5589 }
5590
5591 // Swap tile loops if outer_dims_perm is available.
5592 if (!outerDimsPerm.empty())
5593 applyPermutationToVector(resultShape, outerDimsPerm);
5594
5595 // Append the inner tile dimensions.
5596 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5597 return resultShape;
5598}
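// Worked example: inputShape = [128, 256], innerTileSizes = [8, 32],
// innerDimsPos = [0, 1], empty outerDimsPerm: the tiled outer dims become
// [ceil(128/8), ceil(256/32)] = [16, 8], and appending the tiles yields
// [16, 8, 8, 32].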
5599
5600SmallVector<OpFoldResult> PackOp::getResultShape(
5601 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5602 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5603 ArrayRef<int64_t> outerDimsPerm) {
5604 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5605
5606 AffineExpr s0, s1;
5607 bindSymbols(builder.getContext(), s0, s1);
5608 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5609 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5610 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5611 builder, loc, ceilDivExpr,
5612 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5613 }
5614 if (!outerDimsPerm.empty())
5615 applyPermutationToVector(resultDims, outerDimsPerm);
5616 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5617
5618 SmallVector<int64_t> resultTypeShape =
5619 inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
5620 asShapeWithAnyValueAsDynamic(innerTileSizes),
5621 innerDimsPos, outerDimsPerm);
5622
5623 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5624 // result type shape says it's a dynamic dim. This is needed as callers may
5625 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5626 // dynamic dims returned by that.
5627 for (unsigned i = 0; i < resultDims.size(); ++i) {
5628 if (ShapedType::isStatic(resultTypeShape[i]))
5629 continue;
5630 resultDims[i] =
5631 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5632 }
5633
5634 return resultDims;
5635}
5636
5637RankedTensorType PackOp::inferPackedTensorType(
5638 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5639 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5640 SmallVector<int64_t> resultShape = inferPackedShape(
5641 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5642 return RankedTensorType::get(resultShape, sourceType.getElementType());
5643}
5644
5645MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5646 ArrayRef<int64_t> innerTileSizes,
5647 ArrayRef<int64_t> innerDimsPos,
5648 ArrayRef<int64_t> outerDimsPerm) {
5649 SmallVector<int64_t> resultShape = inferPackedShape(
5650 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5651 return MemRefType::get(resultShape, sourceType.getElementType());
5652}
5653
5654Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5655 ArrayRef<OpFoldResult> innerTileSizes,
5656 ArrayRef<int64_t> innerDimsPos,
5657 ArrayRef<int64_t> outerDimsPerm) {
5658 AffineExpr dim0, dim1;
5659 bindDims(b.getContext(), dim0, dim1);
5660 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5661 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5662 {v1, v2});
5663 };
5664
5665 SmallVector<OpFoldResult> mixedSizes;
5666 for (auto [index, value] : llvm::enumerate(
5667 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5668 if (ShapedType::isDynamic(value))
5669 mixedSizes.push_back(
5670 tensor::DimOp::create(b, loc, source, index).getResult());
5671 else
5672 mixedSizes.push_back(b.getIndexAttr(value));
5673 }
5674 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5675 int64_t dimPos = std::get<0>(it);
5676 OpFoldResult tileSize = std::get<1>(it);
5677 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5678 }
5679 if (!outerDimsPerm.empty())
5680 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5681
5682 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5683 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5684 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5685}
5686
5687PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5688 ArrayRef<int64_t> innerPermutation,
5689 ArrayRef<int64_t> outerPermutation) {
5690 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5691 *this, innerPermutation, outerPermutation);
5692 Value transposedDest =
5693 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5694 metadata.innerDimsPos, metadata.outerDimsPerm);
5695 return PackOp::create(b, loc, getSource(), transposedDest,
5696 metadata.innerDimsPos, metadata.innerTiles,
5697 getPaddingValue(), metadata.outerDimsPerm);
5698}
5699
5700template <typename OpTy>
5701static void getPackUnPackEffectsImpl(
5702 OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects) {
5704 // No memory effects for pure tensor semantics
5705 if (op.hasPureTensorSemantics())
5706 return;
5707
5708 for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5709 if (!llvm::isa<MemRefType>(opOperand.get().getType()))
5710 continue;
5711
5712 if (&opOperand == &op.getSourceMutable()) {
5713 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5714 /*effectOnFullRegion=*/true,
5715 SideEffects::DefaultResource::get());
5716 } else if (&opOperand == &op.getDestMutable()) {
5717 effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
5718 /*effectOnFullRegion=*/true,
5719 SideEffects::DefaultResource::get());
5720 effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
5721 /*effectOnFullRegion=*/true,
5722 SideEffects::DefaultResource::get());
5723 }
5724 }
5725}
5726
5727void PackOp::getEffects(
5728 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5729 &effects) {
5730 getPackUnPackEffectsImpl(*this, effects);
5731}
5732
5733void UnPackOp::getEffects(
5734 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5735 &effects) {
5736 getPackUnPackEffectsImpl(*this, effects);
5737}
5738
5739/// Returns true if the tiles and the tiled dims are constant.
5740template <typename OpTy>
5741static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5742 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5743 "applies to only pack or unpack operations");
5744 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5745 ? op.getDestType()
5746 : op.getSourceType();
5747 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5748 for (auto [dimDest, tile] : llvm::zip(
5749 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5750 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5751 if (!constTileSize || ShapedType::isDynamic(dimDest))
5752 return false;
5753 }
5754 return true;
5755}
5756
5757Speculation::Speculatability PackOp::getSpeculatability() {
5758 if (!hasPureTensorSemantics())
5759 return Speculation::NotSpeculatable;
5760 if (getPaddingValue())
5761 return Speculation::Speculatable;
5762
5763 // The verifier already rejects operations if we can statically prove that
5764 // the sizes of the tiles do not perfectly divide the dimension; thus, check
5765 // only that the tiles and the tiled inner dimensions are constant.
5766 if (!areTilesAndTiledDimsAllConstant(*this))
5767 return Speculation::NotSpeculatable;
5768
5769 return Speculation::Speculatable;
5770}
5771
5772// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5773// dimensions for pack and unpack.
5774static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5775 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5776 return false;
5777 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5778 return true;
5779 // Outer dims permutation is optional.
5780 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5781 // identity permutation.
5782 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5783 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5784}
5785
5786// Return true if pack and unpack have the same tiles.
5787// Same SSA values or same integer constants.
5788static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5789 auto packTiles = packOp.getMixedTiles();
5790 auto unPackTiles = unPackOp.getMixedTiles();
5791 if (packTiles.size() != unPackTiles.size())
5792 return false;
5793 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5794 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5795 return false;
5796 }
5797 return true;
5798}
5799
5800/// Returns true if the pack op does not need a padding value.
5801static bool paddingIsNotNeeded(PackOp op) {
5802 auto srcType = op.getSourceType();
5803 if (llvm::any_of(op.getInnerDimsPos(),
5804 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5805 return false;
5806 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5807 return false;
5808 return !PackOp::requirePaddingValue(
5809 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5810 op.getOuterDimsPerm(), op.getMixedTiles());
5811}
5812
5813/// Returns true if the `srcShape` or `destShape` is different from the one in
5814/// `packOp` and populates each with the inferred static shape.
5815static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5816 SmallVectorImpl<int64_t> &destShape) {
5817 bool changeNeeded = false;
5818 srcShape.assign(packOp.getSourceType().getShape().begin(),
5819 packOp.getSourceType().getShape().end());
5820 destShape.assign(packOp.getDestType().getShape().begin(),
5821 packOp.getDestType().getShape().end());
5822 llvm::SmallSetVector<int64_t, 4> innerDims;
5823 innerDims.insert_range(packOp.getInnerDimsPos());
5824 SmallVector<int64_t> inverseOuterDimsPerm;
5825 if (!packOp.getOuterDimsPerm().empty())
5826 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5827 int srcRank = packOp.getSourceRank();
5828 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5829 if (innerDims.contains(i))
5830 continue;
5831 int64_t srcPos = i;
5832 int64_t destPos = i;
5833 if (!inverseOuterDimsPerm.empty())
5834 destPos = inverseOuterDimsPerm[srcPos];
5835 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5836 ShapedType::isDynamic(destShape[destPos])) {
5837 continue;
5838 }
5839 int64_t size = srcShape[srcPos];
5840 if (ShapedType::isDynamic(size))
5841 size = destShape[destPos];
5842 srcShape[srcPos] = size;
5843 destShape[destPos] = size;
5844 changeNeeded = true;
5845 }
5846 return changeNeeded;
5847}
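// Worked sketch (shapes assumed): packing tensor<?x256xf32> into
// tensor<128x8x32xf32> with inner_dims_pos = [1] leaves dim 0 untiled, so
// the dynamic source dim is inferred to be 128 from the destination, and a
// tensor.cast is inserted by the canonicalization below.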
5848
5849LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5850 // TODO: Support Memref PackOp. Temporarily return failure.
5851 if (!packOp.hasPureTensorSemantics())
5852 return failure();
5853
5854 // Fold pack(unpack(x)) to x.
5855 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5856 if (unPackOp.getSourceType() == packOp.getDestType() &&
5857 !packOp.getPaddingValue() &&
5858 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5859 haveSameTiles(packOp, unPackOp)) {
5860 rewriter.replaceOp(packOp, unPackOp.getSource());
5861 return success();
5862 }
5863 }
5864
5865 // Fold optional PaddingValue operand away if padding is not needed.
5866 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5867 rewriter.startOpModification(packOp);
5868 packOp.getPaddingValueMutable().clear();
5869 rewriter.finalizeOpModification(packOp);
5870 return success();
5871 }
5872
5873 // Insert tensor.cast ops if static shape inference is available.
5874 SmallVector<int64_t> srcShape, destShape;
5875 if (inferStaticShape(packOp, srcShape, destShape)) {
5876 Location loc = packOp.getLoc();
5877 Value source = packOp.getSource();
5878 if (srcShape != packOp.getSourceType().getShape()) {
5879 auto newSrcType = packOp.getSourceType().clone(srcShape);
5880 source =
5881 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5882 }
5883 Value dest = packOp.getDest();
5884 ShapedType originalResultType = packOp.getDestType();
5885 bool needUpdateDestType = (destShape != originalResultType.getShape());
5886 if (needUpdateDestType) {
5887 auto newDestType = packOp.getDestType().clone(destShape);
5888 dest =
5889 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5890 }
5891 rewriter.modifyOpInPlace(packOp, [&] {
5892 packOp.getSourceMutable().assign(source);
5893 packOp.getDestMutable().assign(dest);
5894 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5895 });
5896 // Insert a cast if needed
5897 if (needUpdateDestType) {
5898 rewriter.setInsertionPointAfter(packOp);
5899 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5900 packOp.getResult());
5901 rewriter.replaceAllUsesExcept(packOp.getResult(), castOp, castOp);
5902 }
5903 return success();
5904 }
5905
5906 return failure();
5907}
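// Hedged sketch of the pack(unpack(x)) fold above - when the types, the
// inner/outer dim attributes, and the tiles all match and there is no
// padding value, %1 below is replaced by %x directly:
//
// ```mlir
// %0 = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//        into %e : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
// %1 = linalg.pack %0 inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//        into %f : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// ```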
5908
5909template <typename PackOrUnpackOp>
5910static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
5911 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5912 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5913 "Function meant for pack/unpack");
5914 // This is a pad if packing only adds ones and we don't transpose dimensions.
5915
5916 // Check that we are not transposing any dimensions.
5917 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5918 int64_t numPackedDims = innerDimsPos.size();
5919 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5920 if (orderedDims != innerDimsPos) {
5921 // Dimensions don't happen in order.
5922 return false;
5923 }
5924
5925 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5926 int64_t packedRank = packedTensorType.getRank();
5927 // At this point we know that we are taking numPackedDims outer
5928 // dimensions and pushing them all the way as the inner most dimensions.
5929 // What's left on the outer most dimensions is, in this order:
5930 // - the factor of the packed dimensions, then
5931 // - the untouched dimensions
5932 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5933 // if all the dimensions that bubble outward are ones.
5934 // Therefore check that all the dimensions but the numPackedDims inner most
5935 // ones are ones.
5936 return llvm::all_of(
5937 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5938 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5939}
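// Illustrative example: packing tensor<8x16xf32> into tensor<1x1x8x16xf32>
// with inner_dims_pos = [0, 1] leaves only unit outer dims, so the pack is
// pad-like in the sense checked above.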
5940
5941bool PackOp::isLikePad() {
5942 auto packedTensorType =
5943 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5944 return isLikePadUnPad(*this, packedTensorType);
5945}
5946
5947::mlir::LogicalResult
5948PackOp::fold(FoldAdaptor adaptor,
5949 SmallVectorImpl<OpFoldResult> &results) {
5950 if (!hasPureTensorSemantics())
5951 return failure();
5952 std::optional<Attribute> paddingValue;
5953 if (auto pad = adaptor.getPaddingValue())
5954 paddingValue = pad;
5955 if (OpFoldResult reshapedSource = reshapeConstantSource(
5956 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5957 cast<TensorType>(getDestType()), paddingValue)) {
5958 results.push_back(reshapedSource);
5959 return success();
5960 }
5961 return failure();
5962}
5963
5964 /// Folds a tensor.cast op into a consuming PackOp if the
5965 /// `tensor.cast` has a source that is more static than the consuming op.
5966///
5967/// Example:
5968/// ```mlir
5969/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5970/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5971/// ```
5972///
5973/// folds into:
5974///
5975/// ```mlir
5976/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5977/// ```
5978 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5979 using OpRewritePattern<PackOp>::OpRewritePattern;
5980
5981 LogicalResult matchAndRewrite(PackOp op,
5982 PatternRewriter &rewriter) const override {
5983 // TODO: Support Memref PackOp. Temporarily return failure.
5984 if (!op.hasPureTensorSemantics())
5985 return failure();
5986
5987 if (!tensor::hasFoldableTensorCastOperand(op))
5988 return failure();
5989
5990 SmallVector<Type> newResultTypes(op->getResultTypes());
5991 SmallVector<Value> newOperands =
5992 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5993
5994 // Get the updated mixed-tile-sizes attribute.
5995 SmallVector<OpFoldResult> newMixedTileSizes =
5996 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5997
5998 // Clone op.
5999 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6000 // this point. However, in practice, we use them for things that we'd like
6001 // to preserve. Implement a better abstraction.
6002 PackOp newOp =
6003 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6004 op.getInnerDimsPos(), newMixedTileSizes,
6005 op.getPaddingValue(), op.getOuterDimsPerm());
6006 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6007
6008 // Replace op.
6009 Value oldResult = op.getResult();
6010 Value newResult = newOp.getResult();
6011 Value replacement =
6012 (newResult.getType() != oldResult.getType())
6013 ? tensor::CastOp::create(rewriter, op->getLoc(),
6014 oldResult.getType(), newResult)
6015 : newResult;
6016
6017 rewriter.replaceOp(op, {replacement});
6018
6019 return success();
6020 }
6021};
6022
6023//===----------------------------------------------------------------------===//
6024// UnPackOp
6025//===----------------------------------------------------------------------===//
6026
6027void UnPackOp::getAsmResultNames(
6028 function_ref<void(Value, StringRef)> setNameFn) {
6029 if (!getResults().empty())
6030 setNameFn(getResult(), "unpack");
6031}
6032
6033// Custom parser for UnPackOp that handles the memref/tensor case distinction
6034ParseResult UnPackOp::parse(OpAsmParser &parser, OperationState &result) {
6035 OpAsmParser::UnresolvedOperand source, dest;
6036 SmallVector<OpAsmParser::UnresolvedOperand> dynamicTiles;
6037 SmallVector<int64_t> staticTiles;
6038 DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
6039 Type sourceType, destType, resultType;
6040
6041 if (parser.parseOperand(source))
6042 return failure();
6043
6044 if (succeeded(parser.parseOptionalKeyword("outer_dims_perm"))) {
6045 if (parser.parseEqual())
6046 return failure();
6047
6048 SmallVector<int64_t> outerDimsPermVec;
6049 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() -> ParseResult {
6050 int64_t value;
6051 if (parser.parseInteger(value))
6052 return failure();
6053 outerDimsPermVec.push_back(value);
6054 return success();
6055 }))
6056 return failure();
6057 outerDimsPerm = parser.getBuilder().getDenseI64ArrayAttr(outerDimsPermVec);
6058 }
6059
6060 if (parser.parseKeyword("inner_dims_pos") || parser.parseEqual())
6061 return failure();
6062
6063 SmallVector<int64_t> innerDimsPosVec;
6064 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() -> ParseResult {
6065 int64_t value;
6066 if (parser.parseInteger(value))
6067 return failure();
6068 innerDimsPosVec.push_back(value);
6069 return success();
6070 }))
6071 return failure();
6072 innerDimsPos = parser.getBuilder().getDenseI64ArrayAttr(innerDimsPosVec);
6073
6074 if (parser.parseKeyword("inner_tiles") || parser.parseEqual())
6075 return failure();
6076
6077 DenseI64ArrayAttr staticTilesAttr;
6078 if (parseDynamicIndexList(parser, dynamicTiles, staticTilesAttr))
6079 return failure();
6080 for (auto val : staticTilesAttr.asArrayRef())
6081 staticTiles.push_back(val);
6082
6083 if (parser.parseKeyword("into") || parser.parseOperand(dest))
6084 return failure();
6085
6086 if (parser.parseOptionalAttrDict(result.attributes))
6087 return failure();
6088
6089 if (parser.parseColon() || parser.parseType(sourceType))
6090 return failure();
6091
6092 bool hasArrow = succeeded(parser.parseOptionalArrow());
6093 if (hasArrow) {
6094 if (parser.parseType(destType))
6095 return failure();
6096 }
6097
6098 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6099 if (!hasArrow) {
6100 return parser.emitError(parser.getCurrentLocation(),
6101 "pack/unpack requires '->' and destination type");
6102 }
6103
6104 if (!isMemRef)
6105 resultType = destType;
6106
6107 if (parser.resolveOperand(source, sourceType, result.operands) ||
6108 parser.resolveOperand(dest, destType, result.operands))
6109 return failure();
6110
6111 if (!dynamicTiles.empty() &&
6112 parser.resolveOperands(dynamicTiles, parser.getBuilder().getIndexType(),
6113 result.operands))
6114 return failure();
6115
6116 result.addAttribute("static_inner_tiles",
6117 parser.getBuilder().getDenseI64ArrayAttr(staticTiles));
6118 result.addAttribute("inner_dims_pos", innerDimsPos);
6119 if (outerDimsPerm)
6120 result.addAttribute("outer_dims_perm", outerDimsPerm);
6121
6122 SmallVector<int32_t> segmentSizes = {
6123 1, 1, 0, static_cast<int32_t>(dynamicTiles.size())};
6124 result.addAttribute("operandSegmentSizes",
6125 parser.getBuilder().getDenseI32ArrayAttr(segmentSizes));
6126
6127 if (!isMemRef)
6128 result.addTypes(resultType);
6129
6130 return success();
6131}
6132
6133void UnPackOp::print(OpAsmPrinter &p) {
6134 p << " " << getSource();
6135
6136 if (!getOuterDimsPerm().empty()) {
6137 p << " outer_dims_perm = [";
6138 llvm::interleaveComma(getOuterDimsPerm(), p);
6139 p << "]";
6140 }
6141
6142 p << " inner_dims_pos = [";
6143 llvm::interleaveComma(getInnerDimsPos(), p);
6144 p << "]";
6145
6146 p << " inner_tiles = ";
6147 printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
6148
6149 p << " into " << getDest();
6150
6151 p.printOptionalAttrDict((*this)->getAttrs(),
6152 {"static_inner_tiles", "inner_dims_pos",
6153 "outer_dims_perm", "operandSegmentSizes"});
6154
6155 p << " : " << getSource().getType();
6156 p << " -> " << getDest().getType();
6157}
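// The parser/printer pair above round-trips syntax of this form (illustrative
// example values):
//   %0 = linalg.unpack %src outer_dims_perm = [0, 1]
//       inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %dst : tensor<16x8x8x32xf32> -> tensor<128x256xf32>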
6158
6159LogicalResult
6160UnPackOp::reifyResultShapes(OpBuilder &builder,
6161 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
6162 if (!hasPureTensorSemantics())
6163 return failure();
6164 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
6165}
6166
6167DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
6168 return getDimAndTileMappingImpl(*this);
6169}
6170
6171SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6172 return getMixedTilesImpl(*this);
6173}
6174
6175SmallVector<int64_t> UnPackOp::getStaticTiles() {
6176 return getStaticTilesImpl(*this);
6177}
6178
6179ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6180 ShapedType destType = getDestType();
6181 int64_t destRank = destType.getRank();
6182 return getSourceType().getShape().take_front(destRank);
6183}
6184
6185SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6186 auto innerDimsPos = getInnerDimsPos();
6187 SmallVector<int64_t> outerDims(getAllOuterDims());
6188 SmallVector<int64_t> res;
6189
6190 // Recover the original order of the outer dims.
6191 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6192 invertPermutationVector(outerDimPermInv);
6193 if (!outerDimPermInv.empty())
6194 applyPermutationToVector(outerDims, outerDimPermInv);
6195
6196 // Collect the outer dims corresponding to the tiled inner dims.
6197 for (auto index : innerDimsPos)
6198 res.push_back(outerDims[index]);
6199
6200 return res;
6201}
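// For illustration (example values): with a source of tensor<16x8x32xf32>,
// outer_dims_perm = [1, 0] and inner_dims_pos = [1], the stored outer dims
// are [16, 8]; undoing the permutation recovers [8, 16], and the single
// tiled outer dim returned is 16.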
6202
6203LogicalResult UnPackOp::verify() {
6204 return commonVerifierPackAndUnPackOp(*this);
6205}
6206
6207Speculation::Speculatability UnPackOp::getSpeculatability() {
6208 if (!hasPureTensorSemantics())
6209 return Speculation::NotSpeculatable;
6210 // See PackOp::getSpeculatability.
6211 if (!areTilesAndTiledDimsAllConstant(*this))
6212 return Speculation::NotSpeculatable;
6213
6214 return Speculation::Speculatable;
6215}
6216
6217void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6218 Value dest, ArrayRef<int64_t> innerDimsPos,
6219 ArrayRef<OpFoldResult> innerTiles,
6220 ArrayRef<int64_t> outerDimsPerm) {
6221 assert(innerDimsPos.size() == innerTiles.size() &&
6222 "number of tile sizes specified must match the specified number of "
6223 "original dimensions to be tiled");
6224 SmallVector<int64_t> staticTileSizes;
6225 SmallVector<Value> dynamicTileSizes;
6226 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
6227 build(builder, state, dest.getType(), source, dest,
6228 outerDimsPerm.empty() ? nullptr
6229 : builder.getDenseI64ArrayAttr(outerDimsPerm),
6230 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
6231 builder.getDenseI64ArrayAttr(staticTileSizes));
6232}
6233
6234Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
6235 Value source,
6236 ArrayRef<OpFoldResult> innerTileSizes,
6237 ArrayRef<int64_t> innerDimsPos,
6238 ArrayRef<int64_t> outerDimsPerm) {
6239 AffineExpr sym0, sym1;
6240 bindSymbols(b.getContext(), sym0, sym1);
6241 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6242 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
6243 };
6244
6245 SmallVector<OpFoldResult> mixedSizes;
6246 auto srcType = llvm::cast<RankedTensorType>(source.getType());
6247 for (auto i :
6248 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6249 if (srcType.isDynamicDim(i))
6250 mixedSizes.push_back(
6251 tensor::DimOp::create(b, loc, source, i).getResult());
6252 else
6253 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
6254 }
6255 if (!outerDimsPerm.empty()) {
6256 applyPermutationToVector(
6257 mixedSizes, invertPermutationVector(outerDimsPerm));
6258 }
6259
6260 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6261 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6262
6263 auto elemType = srcType.getElementType();
6264 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
6265}
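// For illustration (example values): for a packed source tensor<16x8x8x32xf32>
// with inner tiles [8, 32] on dims [0, 1] and no outer permutation, the outer
// sizes [16, 8] are multiplied by their tiles to [128, 256], and the returned
// destination is tensor.empty() of type tensor<128x256xf32>.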
6266
6267UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
6268 Value transposedSource,
6269 ArrayRef<int64_t> innerPermutation,
6270 ArrayRef<int64_t> outerPermutation) {
6271 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
6272 *this, innerPermutation, outerPermutation);
6273 return UnPackOp::create(b, loc, transposedSource, getDest(),
6274 metadata.innerDimsPos, metadata.innerTiles,
6275 metadata.outerDimsPerm);
6276}
6277
6278/// Returns true if the `srcShape` or `destShape` is different from the one in
6279/// `op` and populates each with the inferred static shape.
6280static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
6281 SmallVectorImpl<int64_t> &destShape) {
6282 bool changeNeeded = false;
6283 srcShape.assign(op.getSourceType().getShape().begin(),
6284 op.getSourceType().getShape().end());
6285 destShape.assign(op.getDestType().getShape().begin(),
6286 op.getDestType().getShape().end());
6287 llvm::SmallSetVector<int64_t, 4> innerDims;
6288 innerDims.insert_range(op.getInnerDimsPos());
6289 SmallVector<int64_t> inverseOuterDimsPerm;
6290 if (!op.getOuterDimsPerm().empty())
6291 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
6292 int destRank = op.getDestRank();
6293 for (auto i : llvm::seq<int64_t>(0, destRank)) {
6294 if (innerDims.contains(i))
6295 continue;
6296 int64_t srcPos = i;
6297 int64_t destPos = i;
6298 if (!inverseOuterDimsPerm.empty())
6299 srcPos = inverseOuterDimsPerm[destPos];
6300 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6301 ShapedType::isDynamic(destShape[destPos])) {
6302 continue;
6303 }
6304 int64_t size = srcShape[srcPos];
6305 if (ShapedType::isDynamic(size))
6306 size = destShape[destPos];
6307 srcShape[srcPos] = size;
6308 destShape[destPos] = size;
6309 changeNeeded = true;
6310 }
6311 return changeNeeded;
6312}
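// For illustration (example values): for an unpack with inner_dims_pos = [1],
// source tensor<?x4x16xf32> and destination tensor<128x?xf32>, dim 0 is
// untiled, so its static size 128 propagates into srcShape ([128, 4, 16]),
// while the tiled dim 1 is skipped.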
6313
6314LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6315 PatternRewriter &rewriter) {
6316 // TODO: Support Memref UnPackOp. Temporarily return failure.
6317 if (!unPackOp.hasPureTensorSemantics())
6318 return failure();
6319
6320 /// unpack(pack(x)) -> x
6321 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6322 if (packOp.getSourceType() != unPackOp.getDestType())
6323 return failure();
6324 if (packOp.getPaddingValue() ||
6325 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
6326 !haveSameTiles(packOp, unPackOp))
6327 return failure();
6328 rewriter.replaceOp(unPackOp, packOp.getSource());
6329 return success();
6330 }
6331 /// unpack(destinationStyleOp(x)) -> unpack(x)
6332 if (auto dstStyleOp =
6333 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6334 auto destValue = cast<OpResult>(unPackOp.getDest());
6335 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6336 rewriter.modifyOpInPlace(unPackOp,
6337 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6338 return success();
6339 }
6340 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
6341 if (unPackOp->hasOneUse()) {
6342 auto extractSliceUser =
6343 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6344 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6345 OpBuilder::InsertionGuard g(rewriter);
6346 rewriter.setInsertionPoint(unPackOp);
6347 auto newDest = tensor::ExtractSliceOp::create(
6348 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6349 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6350 extractSliceUser.getMixedStrides());
6351 rewriter.modifyOpInPlace(unPackOp, [&]() {
6352 unPackOp.setDpsInitOperand(0, newDest);
6353 unPackOp.getResult().setType(newDest.getType());
6354 });
6355 rewriter.replaceOp(extractSliceUser, unPackOp);
6356 return success();
6357 }
6358 }
6359
6360 // Insert tensor.cast ops if static shape inference is available.
6361 SmallVector<int64_t> srcShape, destShape;
6362 if (inferStaticShape(unPackOp, srcShape, destShape)) {
6363 Location loc = unPackOp.getLoc();
6364 Value source = unPackOp.getSource();
6365 if (srcShape != unPackOp.getSourceType().getShape()) {
6366 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6367 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6368 unPackOp.getSource());
6369 }
6370 Value dest = unPackOp.getDest();
6371 if (destShape != unPackOp.getDestType().getShape()) {
6372 auto newDestType = unPackOp.getDestType().clone(destShape);
6373 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6374 unPackOp.getDest());
6375 }
6376 UnPackOp newOp = UnPackOp::create(
6377 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6378 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6379 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6380 unPackOp, unPackOp.getResult().getType(), newOp.getResult());
6381 return success();
6382 }
6383
6384 return failure();
6385}
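// For illustration (example values), the unpack(pack(x)) -> x rewrite above
// folds
//   %p = linalg.pack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %pd : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//   %u = linalg.unpack %p inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %ud : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
// directly to %x: the types match, the inner/outer attributes and tiles
// agree, and there is no padding value.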
6386
6387bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6388 // Rank-reduced folding is not supported.
6389 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6390 return false;
6391 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6392 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6393 return false;
6394 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6395 SmallVector<int64_t> outerShapeWithoutTranspose =
6396 getPackedOuterShapeWithoutTransposition(*this);
6397 for (auto [pos, tileSize] :
6398 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6399 if (unpackedTypeAfterFold.isDynamicDim(pos))
6400 return false;
6401 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6402 return false;
6403 if (ShapedType::isDynamic(tileSize))
6404 return false;
6405 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6406 unpackedTypeAfterFold.getDimSize(pos);
6407 if (paddingSize >= tileSize)
6408 return false;
6409 }
6410 return true;
6411}
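// For illustration (example values): unpacking tensor<16x8x8x32xf32> yields
// tensor<128x256xf32>; an extract_slice of sizes [127, 250] at zero offsets
// with unit strides is foldable, since the residual padding per tiled dim
// (16 * 8 - 127 = 1 < 8, and 8 * 32 - 250 = 6 < 32) stays below the tile size.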
6412
6413bool UnPackOp::isLikeUnPad() {
6414 ShapedType packedTensorType = getSourceType();
6415 return isLikePadUnPad(*this, packedTensorType);
6416}
6417
6418::mlir::LogicalResult
6419UnPackOp::fold(FoldAdaptor adaptor,
6420 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6421 // TODO: Support Memref UnPackOp. Temporarily return failure.
6422 if (!hasPureTensorSemantics())
6423 return failure();
6424
6425 if (OpFoldResult reshapedSource = reshapeConstantSource(
6426 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6427 cast<TensorType>(getResult().getType()))) {
6428 results.push_back(reshapedSource);
6429 return success();
6430 }
6431 return failure();
6432}
6433
6434 /// Folds a tensor.cast op into a consuming UnPackOp if the
6435 /// `tensor.cast` has a source that is more static than the consuming op.
6436///
6437/// Example:
6438/// ```mlir
6439/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6440/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6441/// ```
6442///
6443/// folds into:
6444///
6445/// ```mlir
6446 /// %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6447/// ```
6448struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6449 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6450
6451 LogicalResult matchAndRewrite(UnPackOp op,
6452 PatternRewriter &rewriter) const override {
6453 // TODO: Support Memref UnPackOp. Temporarily return failure.
6454 if (!op.hasPureTensorSemantics())
6455 return failure();
6456
6457 if (!tensor::hasFoldableTensorCastOperand(op))
6458 return failure();
6459
6460 SmallVector<Type> newResultTypes(op->getResultTypes());
6461 SmallVector<Value> newOperands =
6462 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
6463 Value sourceTensor = newOperands[0];
6464
6465 // Get the updated mixed-tile-sizes attribute.
6466 SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
6467 rewriter, sourceTensor.getType(), op.getMixedTiles());
6468
6469 // Clone op.
6470 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6471 // this point. However, in practice, we use them for things that we'd like
6472 // to preserve. Implement a better abstraction.
6473 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6474 newOperands[1], op.getInnerDimsPos(),
6475 newMixedTileSizes, op.getOuterDimsPerm());
6476 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6477
6478 // Replace op.
6479 Value oldResult = op.getResult();
6480 Value newResult = newOp.getResult();
6481 Value replacement =
6482 (newResult.getType() != oldResult.getType())
6483 ? tensor::CastOp::create(rewriter, op->getLoc(),
6484 oldResult.getType(), newResult)
6485 : newResult;
6486
6487 rewriter.replaceOp(op, {replacement});
6488
6489 return success();
6490 }
6491};
6492
6493//===----------------------------------------------------------------------===//
6494// BatchReduceMatmulOp
6495//===----------------------------------------------------------------------===//
6496SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6497 return SmallVector<utils::IteratorType>{
6498 utils::IteratorType::reduction, utils::IteratorType::parallel,
6499 utils::IteratorType::parallel, utils::IteratorType::reduction};
6500}
6501
6502 SmallVector<AffineMap>
6503 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6504 AffineExpr d0, d1, d2, d3;
6505 SmallVector<AffineMap> indexingMaps;
6506 bindDims(context, d0, d1, d2, d3);
6507 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6508 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6509 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6510 return indexingMaps;
6511}
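// Written as affine maps, the defaults built above are:
//   affine_map<(b, m, n, k) -> (b, m, k)>  // LHS
//   affine_map<(b, m, n, k) -> (b, k, n)>  // RHS
//   affine_map<(b, m, n, k) -> (m, n)>     // output; b and k are reduced.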
6512
6513bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6514 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6515 if (!maps)
6516 return false;
6517 if (maps.size() != 3)
6518 return false;
6519 auto positions = getAffineResultPositions(maps);
6520 if (failed(positions))
6521 return false;
6522 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6523 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6524 (*positions)[2] == SmallVector<int64_t>{1, 2};
6525}
6526unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6527
6528std::string BatchReduceMatmulOp::getLibraryCallName() {
6529 return generateLibraryCallName(getOperation());
6530}
6531
6532 /// Check if the op has broadcast and/or transpose semantics. Returns true if
6533 /// the user-defined indexing maps are not equal to the default maps.
6534bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6535 SmallVector<AffineMap, 3> defaultMaps =
6536 getDefaultIndexingMaps(this->getContext());
6537 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6538 return defaultMaps != explicitMaps;
6539}
6540
6541 /// Returns true if the given bcastMap is a valid broadcast map. A valid
6542 /// broadcast map must include the K dimension.
6543 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
6544 /// necessary for both input matrices simultaneously. We can relax this
6545 /// condition to require the K dimension in only one input matrix map and
6546 /// infer the K dimension for the other input matrix map from the one that
6547 /// already has it.
6548bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6549 bool isLHS) {
6550 assert(bcastMap.getNumResults() < 3 &&
6551 "Expected less than 3 result dim expr.");
6552 bool isValid = false;
6553 enum Indices { batchPos, mPos, nPos, kPos };
6554 if (bcastMap.getNumResults() == 1) {
6555 AffineExpr expr = bcastMap.getResult(0);
6556 isValid = expr.isFunctionOfDim(kPos);
6557 } else if (bcastMap.getNumResults() == 2) {
6558 AffineExpr expr0 = bcastMap.getResult(0);
6559 AffineExpr expr1 = bcastMap.getResult(1);
6560 isValid =
6561 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6562 expr0.isFunctionOfDim(mPos)) &&
6563 expr1.isFunctionOfDim(kPos))
6564 : ((expr0.isFunctionOfDim(batchPos) &&
6565 expr1.isFunctionOfDim(kPos)) ||
6566 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6567 }
6568 return isValid;
6569}
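// For illustration (example maps): on the LHS, affine_map<(b, m, n, k) -> (k)>
// and affine_map<(b, m, n, k) -> (m, k)> are accepted, while
// affine_map<(b, m, n, k) -> (m, n)> is rejected because it omits k.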
6570
6571void BatchReduceMatmulOp::regionBuilder(
6572 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6573 function_ref<InFlightDiagnostic()> emitError) {
6574 if (emitError && block.getNumArguments() != 3) {
6575 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6576 << block.getNumArguments();
6577 return;
6578 }
6579 assert(block.getNumArguments() == 3 &&
6580 "BatchReduceMatmulOp regionBuilder expects 3 args");
6581 RegionBuilderHelper helper(b, block);
6582 SmallVector<Value> yields;
6583
6584 auto toType = block.getArgument(2).getType();
6585 Value castValA =
6586 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6587 Value castValB =
6588 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6589 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6590 Value addVal =
6591 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6592 yields.push_back(addVal);
6593 helper.yieldOutputs(yields);
6594}
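// Schematically, the region built above computes (pseudo-IR; the actual cast
// and arithmetic ops depend on the operand element types):
//   ^bb0(%a: A, %b: B, %acc: C):
//     %0 = cast_signed %a : A to C
//     %1 = cast_signed %b : B to C
//     %2 = mul %0, %1 : C
//     %3 = add %acc, %2 : C
//     linalg.yield %3 : C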
6595
6596ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6597 OperationState &result) {
6598 SmallVector<Attribute, 3> indexingMapsAttr;
6599 Attribute mapAttr;
6600 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6601 if (parser.parseEqual())
6602 return failure();
6603 if (parser.parseLSquare())
6604 return failure();
6605
6606 do {
6607 if (parser.parseAttribute(mapAttr))
6608 return failure();
6609 if (!isa<AffineMapAttr>(mapAttr)) {
6610 return parser.emitError(parser.getCurrentLocation(),
6611 "expected affine map attribute");
6612 }
6613 indexingMapsAttr.push_back(mapAttr);
6614
6615 if (parser.parseOptionalComma())
6616 break;
6617 } while (true);
6618
6619 if (parser.parseRSquare())
6620 return failure();
6621 }
6622 // Initialize indexingMaps, if not supplied explicitly.
6623 if (indexingMapsAttr.empty()) {
6624 indexingMapsAttr = llvm::map_to_vector(
6625 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6626 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6627 }
6628 result.addAttribute("indexing_maps",
6629 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6630 return ::parseNamedStructuredOp(parser, result,
6631 BatchReduceMatmulOp::getNumRegionArgs(),
6632 BatchReduceMatmulOp::getRegionBuilder());
6633}
6634
6635void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6636 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6637 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6638 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6639
6640 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6641 p << " indexing_maps = [";
6642 llvm::interleaveComma(getIndexingMaps(), p,
6643 [&](Attribute attr) { p.printAttribute(attr); });
6644 p << "]";
6645 }
6646
6647 SmallVector<StringRef, 3> elidedAttrs = {
6648 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6649 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6650 elidedAttrs);
6651}
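// When the maps are not the defaults, the custom form prints them explicitly,
// e.g. (illustrative; #map_lhs etc. are hypothetical map aliases):
//   linalg.batch_reduce_matmul indexing_maps = [#map_lhs, #map_rhs, #map_out]
//       ins(%A, %B : tensor<2x3x5xf32>, tensor<2x5x7xf32>)
//       outs(%C : tensor<3x7xf32>) -> tensor<3x7xf32>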
6652
6653 /// Verify the user-defined indexing maps.
6654LogicalResult BatchReduceMatmulOp::verify() {
6655 // Verification of pure batch_reduce_matmul is handled by
6656 // verifyStructuredOpInterface().
6657 if (!hasUserDefinedMaps())
6658 return success();
6659
6660 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6661 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
6662 return failure();
6663 }
6664 return success();
6665}
6666LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6667 SmallVectorImpl<OpFoldResult> &) {
6668 return memref::foldMemRefCast(*this);
6669}
6670void BatchReduceMatmulOp::getEffects(
6671 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6672 &effects) {
6673 if (hasPureTensorSemantics())
6674 return;
6675 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6676}
6677
6678Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6679 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6680}
6681
6682} // namespace linalg
6683} // namespace mlir
6684
6685//===----------------------------------------------------------------------===//
6686// LinalgDialect
6687//===----------------------------------------------------------------------===//
6688
6689void LinalgDialect::getCanonicalizationPatterns(
6690 RewritePatternSet &results) const {
6691 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6692 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6693}
6694
6695Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6696 Attribute value, Type type,
6697 Location loc) {
6698 return arith::ConstantOp::materialize(builder, value, type, loc);
6699}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition LinalgOps.cpp:58
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
Base type for affine expression.
Definition AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:387
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
result_iterator result_begin()
Definition Operation.h:413
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
unsigned getNumOperands()
Definition Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
result_iterator result_end()
Definition Operation.h:414
result_type_range getResultTypes()
Definition Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator end()
Definition Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
Definition Linalg.h:251
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
Definition Linalg.h:298
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
Definition Linalg.h:157
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
Definition Linalg.h:204
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition LinalgOps.cpp:96
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:67
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:238
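A hedged sketch of convertScalarToDtype, here widening an i32 scalar to f32; the header path is an assumption.
#include "mlir/Dialect/Arith/Utils/Utils.h" // assumed home of the helper
using namespace mlir;
static Value toF32(OpBuilder &b, Location loc, Value i32Scalar) {
  // isUnsignedCast selects zero- vs. sign-extension semantics when the
  // source is an integer type.
  return convertScalarToDtype(b, loc, i32Scalar, b.getF32Type(),
                              /*isUnsignedCast=*/false);
}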
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
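A short sketch of these constant OpFoldResult helpers; StaticValueUtils.h is the assumed header, and the predicates here are hypothetical wrappers.
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include <optional>
using namespace mlir;
static bool isUnitTile(OpFoldResult tile) {
  // Holds both for IntegerAttr{1} and for an SSA value defined by a
  // constant 1.
  std::optional<int64_t> cst = getConstantIntValue(tile);
  return cst && *cst == 1;
}
static bool allUnitStrides(ArrayRef<OpFoldResult> strides) {
  return areAllConstantIntValue(strides, /*value=*/1);
}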
Type getType(OpFoldResult ofr)
Returns the integer type of the integer held by ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
Returns a copy of input rearranged according to permutation: element i of the result is input[permutation[i]].
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR attribute into the given MLIR context, if valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition AffineExpr.h:325
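A compact sketch combining bindDims (bindSymbols works the same way for symbols) with the inversePermutation entry above; a pure-permutation map is required for the inversion to succeed.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;
static AffineMap invertedTranspose(MLIRContext *ctx) {
  AffineExpr d0, d1, d2;
  bindDims(ctx, d0, d1, d2);
  // (d0, d1, d2) -> (d2, d0, d1) is a permutation, hence invertible.
  AffineMap perm =
      AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d2, d0, d1}, ctx);
  return inversePermutation(perm); // (d0, d1, d2) -> (d1, d2, d0)
}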
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
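A sketch of the usual static/dynamic split and re-materialization these two helpers enable; the header paths are assumptions.
#include "mlir/Dialect/Arith/Utils/Utils.h"      // getValueOrCreateConstantIndexOp
#include "mlir/Dialect/Utils/StaticValueUtils.h" // dispatchIndexOpFoldResults
using namespace mlir;
static void splitMixedSizes(OpBuilder &b, Location loc,
                            ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<Value> dynamicSizes;  // SSA values for the dynamic entries
  SmallVector<int64_t> staticSizes; // constants, with a dynamic sentinel
  dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
  // When an SSA value is required unconditionally, fold back or
  // materialize an arith.constant of index type.
  Value first = getValueOrCreateConstantIndexOp(b, loc, mixedSizes.front());
  (void)first;
}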
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
Definition Utils.cpp:1293
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute; if this fails, return the original value.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation vector.
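A worked sketch of the permutation-vector helpers listed above; the include path is an assumption.
#include "mlir/Dialect/Utils/IndexingUtils.h" // assumed home of these helpers
#include <cassert>
using namespace mlir;
static void permutationRoundTrip() {
  SmallVector<int64_t> perm = {2, 0, 1};
  assert(isPermutationVector(perm) && !isIdentityPermutation(perm));
  // Element i of the result is input[perm[i]]: {4, 8, 16} -> {16, 4, 8}.
  SmallVector<int64_t> shape = applyPermutation<int64_t>({4, 8, 16}, perm);
  SmallVector<int64_t> inverse = invertPermutationVector(perm); // {1, 2, 0}
  applyPermutationToVector(shape, inverse); // back to {4, 8, 16}
}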
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) is rewritten to broadcast(transpose(input)).
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
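A simplified stand-in showing the OpRewritePattern shape the folding patterns above share; the pattern itself is illustrative, not one of the actual canonicalizations in this file.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
struct DropIdentityTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // An identity permutation makes the transpose a no-op; forward the
    // input instead.
    if (!isIdentityPermutation(transposeOp.getPermutation()))
      return failure();
    rewriter.replaceOp(transposeOp, transposeOp.getInput());
    return success();
  }
};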
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
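A hedged sketch of assembling an operation generically through OperationState, mirroring the addOperands/addAttribute/addTypes/addRegion entries above; the op name and attribute are hypothetical.
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"
using namespace mlir;
static Operation *buildGenericOp(OpBuilder &b, Location loc, Value operand) {
  OperationState state(loc, "mydialect.myop"); // hypothetical op name
  state.addOperands(operand);
  state.addAttribute("my_attr", b.getI64IntegerAttr(42)); // hypothetical attr
  state.addTypes(operand.getType());
  state.addRegion(); // for ops that carry a body, e.g. linalg.generic
  // Requires the op to be registered, or unregistered ops to be allowed.
  return b.create(state);
}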
Folds a tensor.cast op into a consuming PackOp if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp if the tensor.cast has a source that is more static than the consuming op.
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override